diff --git a/llama/build-info.cpp b/llama/build-info.cpp index e6e66949..63732571 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/clip.cpp b/llama/clip.cpp index 5cbb4532..2039bdc8 100644 --- a/llama/clip.cpp +++ b/llama/clip.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -42,6 +42,10 @@ #include "ggml-metal.h" #endif +#ifdef GGML_USE_CANN +#include "ggml-cann.h" +#endif + #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -891,7 +895,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = peg_0; } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -1027,6 +1031,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_TEE("%s: CLIP using Metal backend\n", __func__); #endif +#ifdef GGML_USE_CANN + new_clip->backend = ggml_backend_cann_init(0); + LOG_TEE("%s: CLIP using CANN backend\n", __func__); +#endif + if (!new_clip->backend) { new_clip->backend = ggml_backend_cpu_init(); @@ -1147,20 +1156,20 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { } if (n < 32) hparams.image_grid_pinpoints[n] = 0; - } catch (std::runtime_error & e) { + } catch (std::runtime_error & /*e*/) { hparams.image_grid_pinpoints[0]=0; } try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); strcpy(hparams.mm_patch_merge_type, gguf_get_val_str(ctx, idx)); - } catch (std::runtime_error & e) { + } catch (std::runtime_error & /*e*/) { strcpy(hparams.mm_patch_merge_type, "flat"); } try { hparams.image_crop_resolution = get_u32(ctx, KEY_IMAGE_CROP_RESOLUTION); // llava-1.6 - } catch(const std::exception& e) { + } catch(const std::exception& /*e*/) { hparams.image_crop_resolution = hparams.image_size; } @@ -1199,7 +1208,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { try { vision_model.class_embedding = get_tensor(new_clip->ctx_data, TN_CLASS_EMBD); new_clip->has_class_embedding = true; - } catch (const std::exception& e) { + } catch (const std::exception& /*e*/) { new_clip->has_class_embedding = false; } @@ -1207,7 +1216,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.pre_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "weight")); vision_model.pre_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_PRE, "v", "bias")); new_clip->has_pre_norm = true; - } catch (std::exception & e) { + } catch (std::exception & /*e*/) { new_clip->has_pre_norm = false; } @@ -1215,21 +1224,21 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight")); vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias")); new_clip->has_post_norm = true; - } catch (std::exception & e) { + } catch (std::exception & /*e*/) { new_clip->has_post_norm = false; } try { vision_model.patch_bias = get_tensor(new_clip->ctx_data, TN_PATCH_BIAS); new_clip->has_patch_bias = true; - } catch (std::exception & e) { + } catch (std::exception & /*e*/) { new_clip->has_patch_bias = false; } try { vision_model.patch_embeddings = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD); vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v")); - } catch(const std::exception& e) { + } catch(const std::exception& /*e*/) { LOG_TEE("%s: failed to load vision model tensors\n", __func__); } @@ -1241,26 +1250,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { // Yi-type llava vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "weight")); vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 1, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { // missing in Yi-type llava vision_model.mm_2_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight")); vision_model.mm_2_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { // Yi-type llava vision_model.mm_3_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "weight")); vision_model.mm_3_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 3, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { // Yi-type llava vision_model.mm_4_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "weight")); vision_model.mm_4_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 4, "bias")); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } try { vision_model.image_newline = get_tensor(new_clip->ctx_data, TN_IMAGE_NEWLINE); // LOG_TEE("%s: image_newline tensor (llava-1.6) found\n", __func__); - } catch (std::runtime_error & e) { } + } catch (std::runtime_error & /*e*/) { } } else if (new_clip->proj_type == PROJECTOR_TYPE_LDP) { // MobileVLM projection vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_MVLM_PROJ_MLP, 1, "weight")); diff --git a/llama/clip.h b/llama/clip.h index 9d95e0b2..8665ad6a 100644 --- a/llama/clip.h +++ b/llama/clip.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/common.cpp b/llama/common.cpp index d87a1dea..f542c129 100644 --- a/llama/common.cpp +++ b/llama/common.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -24,6 +24,10 @@ * SOFTWARE. */ +#if defined(_MSC_VER) +#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING +#endif + #include "common.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -32,7 +36,6 @@ #include "llama.h" #include -#include #include #include #include @@ -217,6 +220,12 @@ int32_t cpu_get_num_math() { // CLI argument parsing // +void gpt_params_handle_hf_token(gpt_params & params) { + if (params.hf_token.empty() && std::getenv("HF_TOKEN")) { + params.hf_token = std::getenv("HF_TOKEN"); + } +} + void gpt_params_handle_model_default(gpt_params & params) { if (!params.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model @@ -226,19 +235,13 @@ void gpt_params_handle_model_default(gpt_params & params) { } params.hf_file = params.model; } else if (params.model.empty()) { - std::string cache_directory = fs_get_cache_directory(); - const bool success = fs_create_directory_with_parents(cache_directory); - if (!success) { - throw std::runtime_error("failed to create cache directory: " + cache_directory); - } - params.model = cache_directory + string_split(params.hf_file, '/').back(); + params.model = fs_get_cache_file(string_split(params.hf_file, '/').back()); } } else if (!params.model_url.empty()) { if (params.model.empty()) { auto f = string_split(params.model_url, '#').front(); f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; + params.model = fs_get_cache_file(string_split(f, '/').back()); } } else if (params.model.empty()) { params.model = DEFAULT_MODEL_PATH; @@ -270,6 +273,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { gpt_params_handle_model_default(params); + gpt_params_handle_hf_token(params); + if (params.escape) { string_process_escapes(params.prompt); string_process_escapes(params.input_prefix); @@ -306,26 +311,22 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { return true; } +#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; } + bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { const char split_delim = ','; llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context. params.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]); return true; } if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads = std::stoi(argv[i]); if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); @@ -333,10 +334,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-tb" || arg == "--threads-batch") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_batch = std::stoi(argv[i]); if (params.n_threads_batch <= 0) { params.n_threads_batch = std::thread::hardware_concurrency(); @@ -344,10 +342,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-td" || arg == "--threads-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_draft = std::stoi(argv[i]); if (params.n_threads_draft <= 0) { params.n_threads_draft = std::thread::hardware_concurrency(); @@ -355,10 +350,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-tbd" || arg == "--threads-batch-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_batch_draft = std::stoi(argv[i]); if (params.n_threads_batch_draft <= 0) { params.n_threads_batch_draft = std::thread::hardware_concurrency(); @@ -366,10 +358,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.prompt = argv[i]; return true; } @@ -382,10 +371,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--prompt-cache") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.path_prompt_cache = argv[i]; return true; } @@ -398,10 +384,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-bf" || arg == "--binary-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i], std::ios::binary); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -417,10 +400,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-f" || arg == "--file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -436,10 +416,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--in-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -450,66 +427,42 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-n" || arg == "--predict" || arg == "--n-predict") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_predict = std::stoi(argv[i]); return true; } if (arg == "--top-k") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.top_k = std::stoi(argv[i]); return true; } if (arg == "-c" || arg == "--ctx-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_ctx = std::stoi(argv[i]); return true; } if (arg == "--grp-attn-n" || arg == "-gan") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.grp_attn_n = std::stoi(argv[i]); return true; } if (arg == "--grp-attn-w" || arg == "-gaw") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.grp_attn_w = std::stoi(argv[i]); return true; } if (arg == "--rope-freq-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_base = std::stof(argv[i]); return true; } if (arg == "--rope-freq-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_scale = std::stof(argv[i]); return true; } if (arg == "--rope-scaling") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } @@ -518,217 +471,148 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--rope-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rope_freq_scale = 1.0f / std::stof(argv[i]); return true; } if (arg == "--yarn-orig-ctx") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_orig_ctx = std::stoi(argv[i]); return true; } if (arg == "--yarn-ext-factor") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_ext_factor = std::stof(argv[i]); return true; } if (arg == "--yarn-attn-factor") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_attn_factor = std::stof(argv[i]); return true; } if (arg == "--yarn-beta-fast") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_beta_fast = std::stof(argv[i]); return true; } if (arg == "--yarn-beta-slow") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.yarn_beta_slow = std::stof(argv[i]); return true; } if (arg == "--pooling") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else { invalid_param = true; } + return true; + } + if (arg == "--attention") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } + else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; } else { invalid_param = true; } return true; } if (arg == "--defrag-thold" || arg == "-dt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.defrag_thold = std::stof(argv[i]); return true; } if (arg == "--samplers") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const auto sampler_names = string_split(argv[i], ';'); sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.top_p = std::stof(argv[i]); return true; } if (arg == "--min-p") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.min_p = std::stof(argv[i]); return true; } if (arg == "--temp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.temp = std::stof(argv[i]); sparams.temp = std::max(sparams.temp, 0.0f); return true; } if (arg == "--tfs") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.tfs_z = std::stof(argv[i]); return true; } if (arg == "--typical") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.typical_p = std::stof(argv[i]); return true; } if (arg == "--repeat-last-n") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_last_n = std::stoi(argv[i]); sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n); return true; } if (arg == "--repeat-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_repeat = std::stof(argv[i]); return true; } if (arg == "--frequency-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_freq = std::stof(argv[i]); return true; } if (arg == "--presence-penalty") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.penalty_present = std::stof(argv[i]); return true; } if (arg == "--dynatemp-range") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.dynatemp_range = std::stof(argv[i]); return true; } if (arg == "--dynatemp-exp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.dynatemp_exponent = std::stof(argv[i]); return true; } if (arg == "--mirostat") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat = std::stoi(argv[i]); return true; } if (arg == "--mirostat-lr") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat_eta = std::stof(argv[i]); return true; } if (arg == "--mirostat-ent") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.mirostat_tau = std::stof(argv[i]); return true; } if (arg == "--cfg-negative-prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.cfg_negative_prompt = argv[i]; return true; } if (arg == "--cfg-negative-prompt-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -742,203 +626,126 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--cfg-scale") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.cfg_scale = std::stof(argv[i]); return true; } if (arg == "-b" || arg == "--batch-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_batch = std::stoi(argv[i]); return true; } if (arg == "-ub" || arg == "--ubatch-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_ubatch = std::stoi(argv[i]); return true; } if (arg == "--keep") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_keep = std::stoi(argv[i]); return true; } if (arg == "--draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_draft = std::stoi(argv[i]); return true; } if (arg == "--chunks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_chunks = std::stoi(argv[i]); return true; } if (arg == "-np" || arg == "--parallel") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_parallel = std::stoi(argv[i]); return true; } if (arg == "-ns" || arg == "--sequences") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_sequences = std::stoi(argv[i]); return true; } if (arg == "--p-split" || arg == "-ps") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.p_split = std::stof(argv[i]); return true; } if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model = argv[i]; return true; } if (arg == "-md" || arg == "--model-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_draft = argv[i]; return true; } if (arg == "-a" || arg == "--alias") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_alias = argv[i]; return true; } if (arg == "-mu" || arg == "--model-url") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.model_url = argv[i]; return true; } - if (arg == "-hfr" || arg == "--hf-repo") { + if (arg == "-hft" || arg == "--hf-token") { if (++i >= argc) { - invalid_param = true; - return true; + invalid_param = true; + return true; } + params.hf_token = argv[i]; + return true; + } + if (arg == "-hfr" || arg == "--hf-repo") { + CHECK_ARG params.hf_repo = argv[i]; return true; } if (arg == "-hff" || arg == "--hf-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hf_file = argv[i]; return true; } if (arg == "--lora") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_adapter.emplace_back(argv[i], 1.0f); - params.use_mmap = false; return true; } if (arg == "--lora-scaled") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const char* lora_adapter = argv[i]; - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); - params.use_mmap = false; - return true; - } - if (arg == "--lora-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.lora_base = argv[i]; return true; } if (arg == "--control-vector") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vectors.push_back({ 1.0f, argv[i], }); return true; } if (arg == "--control-vector-scaled") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG const char* fname = argv[i]; - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vectors.push_back({ std::stof(argv[i]), fname, }); return true; } if (arg == "--control-vector-layer-range") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vector_layer_start = std::stoi(argv[i]); - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.control_vector_layer_end = std::stoi(argv[i]); return true; } if (arg == "--mmproj") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.mmproj = argv[i]; return true; } if (arg == "--image") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.image.emplace_back(argv[i]); return true; } @@ -954,6 +761,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.embedding = true; return true; } + if (arg == "--embd-normalize") { + CHECK_ARG + params.embd_normalize = std::stoi(argv[i]); + return true; + } + if (arg == "--embd-output-format") { + CHECK_ARG + params.embd_out = argv[i]; + return true; + } + if (arg == "--embd-separator") { + CHECK_ARG + params.embd_sep = argv[i]; + return true; + } if (arg == "-if" || arg == "--interactive-first") { params.interactive_first = true; return true; @@ -982,7 +804,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cache_type_v = argv[++i]; return true; } - if (arg == "--multiline-input") { + if (arg == "-mli" || arg == "--multiline-input") { params.multiline_input = true; return true; } @@ -994,6 +816,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.cont_batching = true; return true; } + if (arg == "-nocb" || arg == "--no-cont-batching") { + params.cont_batching = false; + return true; + } if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; return true; @@ -1007,10 +833,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_gpu_layers = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n"); @@ -1019,10 +842,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_gpu_layers_draft = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); @@ -1031,10 +851,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--main-gpu" || arg == "-mg") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.main_gpu = std::stoi(argv[i]); #ifndef GGML_USE_CUDA_SYCL_VULKAN fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); @@ -1042,10 +859,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--split-mode" || arg == "-sm") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string arg_next = argv[i]; if (arg_next == "none") { params.split_mode = LLAMA_SPLIT_MODE_NONE; @@ -1070,10 +884,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--tensor-split" || arg == "-ts") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string arg_next = argv[i]; // split string by , and / @@ -1098,10 +909,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--rpc") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.rpc_servers = argv[i]; return true; } @@ -1110,10 +918,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--numa") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::string value(argv[i]); /**/ if (value == "distribute" || value == "") { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; } else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; } @@ -1126,10 +931,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--verbosity") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.verbosity = std::stoi(argv[i]); return true; } @@ -1142,18 +944,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-r" || arg == "--reverse-prompt") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.antiprompt.emplace_back(argv[i]); return true; } if (arg == "-ld" || arg == "--logdir") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.logdir = argv[i]; if (params.logdir.back() != DIRECTORY_SEPARATOR) { @@ -1162,26 +958,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-lcs" || arg == "--lookup-cache-static") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lookup_cache_static = argv[i]; return true; } if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.lookup_cache_dynamic = argv[i]; return true; } if (arg == "--save-all-logits" || arg == "--kl-divergence-base") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.logits_file = argv[i]; return true; } @@ -1190,26 +977,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--ppl-stride") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ppl_stride = std::stoi(argv[i]); return true; } if (arg == "--ppl-output-type") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ppl_output_type = std::stoi(argv[i]); return true; } if (arg == "-ptc" || arg == "--print-token-count") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_print = std::stoi(argv[i]); return true; } @@ -1222,10 +1000,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--hellaswag-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hellaswag_tasks = std::stoi(argv[i]); return true; } @@ -1234,10 +1009,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--winogrande-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.winogrande_tasks = std::stoi(argv[i]); return true; } @@ -1246,10 +1018,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--multiple-choice-tasks") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.multiple_choice_tasks = std::stoi(argv[i]); return true; } @@ -1266,10 +1035,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-l" || arg == "--logit-bias") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::stringstream ss(argv[i]); llama_token key; char sign; @@ -1299,37 +1065,32 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "--in-prefix-bos") { params.input_prefix_bos = true; + params.enable_chat_template = false; return true; } if (arg == "--in-prefix") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.input_prefix = argv[i]; + params.enable_chat_template = false; return true; } if (arg == "--in-suffix") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.input_suffix = argv[i]; + params.enable_chat_template = false; + return true; + } + if (arg == "--spm-infill") { + params.spm_infill = true; return true; } if (arg == "--grammar") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.grammar = argv[i]; return true; } if (arg == "--grammar-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1344,18 +1105,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-j" || arg == "--json-schema") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG sparams.grammar = json_schema_to_grammar(json::parse(argv[i])); return true; } if (arg == "--override-kv") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!string_parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; @@ -1364,42 +1119,27 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--host") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.hostname = argv[i]; return true; } if (arg == "--port") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.port = std::stoi(argv[i]); return true; } if (arg == "--path") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.public_path = argv[i]; return true; } if (arg == "--api-key") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.api_keys.push_back(argv[i]); return true; } if (arg == "--api-key-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream key_file(argv[i]); if (!key_file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1416,43 +1156,28 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--ssl-key-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ssl_file_key = argv[i]; return true; } if (arg == "--ssl-cert-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.ssl_file_cert = argv[i]; return true; } if (arg == "--timeout" || arg == "-to") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.timeout_read = std::stoi(argv[i]); params.timeout_write = std::stoi(argv[i]); return true; } if (arg == "--threads-http") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_threads_http = std::stoi(argv[i]); return true; } if (arg == "-spf" || arg == "--system-prompt-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1469,10 +1194,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--log-format") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (std::strcmp(argv[i], "json") == 0) { params.log_json = true; } else if (std::strcmp(argv[i], "text") == 0) { @@ -1492,10 +1214,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--slot-save-path") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.slot_save_path = argv[i]; // if doesn't end with DIRECTORY_SEPARATOR, add it if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { @@ -1504,10 +1223,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--chat-template") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!llama_chat_verify_template(argv[i])) { fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); @@ -1517,42 +1233,35 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chat_template = argv[i]; return true; } + if (arg == "--slot-prompt-similarity" || arg == "-sps") { + CHECK_ARG + params.slot_prompt_similarity = std::stof(argv[i]); + return true; + } if (arg == "-pps") { params.is_pp_shared = true; return true; } if (arg == "-npp") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG auto p = string_split(argv[i], split_delim); params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); return true; } if (arg == "-ntg") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG auto p = string_split(argv[i], split_delim); params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); return true; } if (arg == "-npl") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG auto p = string_split(argv[i], split_delim); params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); return true; } if (arg == "--context-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG std::ifstream file(argv[i], std::ios::binary); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -1563,58 +1272,39 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--chunk-size") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.chunk_size = std::stoi(argv[i]); return true; } if (arg == "--chunk-separator") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.chunk_separator = argv[i]; return true; } if (arg == "--junk") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_junk = std::stoi(argv[i]); return true; } if (arg == "--pos") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.i_pos = std::stoi(argv[i]); return true; } if (arg == "-o" || arg == "--output" || arg == "--output-file") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.out_file = argv[i]; + params.cvector_outfile = argv[i]; + params.lora_outfile = argv[i]; return true; } if (arg == "-ofreq" || arg == "--output-frequency") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_out_freq = std::stoi(argv[i]); return true; } if (arg == "--save-frequency") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.n_save_freq = std::stoi(argv[i]); return true; } @@ -1627,13 +1317,43 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--chunk" || arg == "--from-chunk") { - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG params.i_chunk = std::stoi(argv[i]); return true; } + // cvector params + if (arg == "--positive-file") { + CHECK_ARG + params.cvector_positive_file = argv[i]; + return true; + } + if (arg == "--negative-file") { + CHECK_ARG + params.cvector_negative_file = argv[i]; + return true; + } + if (arg == "--pca-batch") { + CHECK_ARG + params.n_pca_batch = std::stoi(argv[i]); + return true; + } + if (arg == "--pca-iter") { + CHECK_ARG + params.n_pca_iterations = std::stoi(argv[i]); + return true; + } + if (arg == "--method") { + CHECK_ARG + std::string value(argv[i]); + /**/ if (value == "pca") { params.cvector_dimre_method = DIMRE_METHOD_PCA; } + else if (value == "mean") { params.cvector_dimre_method = DIMRE_METHOD_MEAN; } + else { invalid_param = true; } + return true; + } + if (arg == "--no-warmup") { + params.warmup = false; + return true; + } #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1645,10 +1365,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa // We have a matching known parameter requiring an argument, // now we need to check if there is anything after this argv // and flag invalid_param or parse it. - if (++i >= argc) { - invalid_param = true; - return true; - } + CHECK_ARG if (!log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i - 1], argv[i])) { invalid_param = true; return true; @@ -1733,7 +1450,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); - options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with (default: '%s')", params.prompt.c_str() }); + options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" + "in conversation mode, this will be used as system prompt\n" + "(default: '%s')", params.prompt.c_str() }); options.push_back({ "*", "-f, --file FNAME", "a file containing the prompt (default: none)" }); options.push_back({ "*", " --in-file FNAME", "an input file (repeat to specify multiple files)" }); options.push_back({ "*", "-bf, --binary-file FNAME", "binary file containing the prompt (default: none)" }); @@ -1748,13 +1467,18 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "halt generation at PROMPT, return control in interactive mode\n" "can be specified more than once for multiple prompts" }); options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" }); - options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" }); + options.push_back({ "main", "-cnv, --conversation", "run in conversation mode, does not print special tokens and suffix/prefix\n" + "if suffix/prefix are not specified, default chat template will be used\n" + "(default: %s)", params.conversation ? "true" : "false" }); options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" }); options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" }); options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" }); options.push_back({ "main infill", " --in-prefix-bos", "prefix BOS to user inputs, preceding the `--in-prefix` string" }); options.push_back({ "main infill", " --in-prefix STRING", "string to prefix user inputs with (default: empty)" }); options.push_back({ "main infill", " --in-suffix STRING", "string to suffix after user inputs with (default: empty)" }); + options.push_back({ "main", " --no-warmup", "skip warming up the model with an empty run" }); + options.push_back({ "server infill", + " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" }); options.push_back({ "sampling" }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" @@ -1788,7 +1512,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "main", " --cfg-negative-prompt-file FNAME", "negative prompt file to use for guidance" }); options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); - + options.push_back({ "main", " --chat-template JINJA_TEMPLATE", + "set custom jinja chat template (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted:\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); options.push_back({ "grammar" }); options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); @@ -1797,8 +1525,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" }); options.push_back({ "embedding" }); - options.push_back({ "embedding", " --pooling {none,mean,cls}", + options.push_back({ "embedding", " --pooling {none,mean,cls,last}", "pooling type for embeddings, use model default if unspecified" }); + options.push_back({ "embedding", " --attention {causal,non-causal}", + "attention type for embeddings, use model default if unspecified" }); options.push_back({ "context hacking" }); options.push_back({ "*", " --rope-scaling {none,linear,yarn}", @@ -1837,6 +1567,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel }); options.push_back({ "*", "-ns, --sequences N", "number of sequences to decode (default: %d)", params.n_sequences }); options.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" }); + options.push_back({ "*", "-nocb, --no-cont-batching", "disable continuous batching" }); options.push_back({ "multi-modality" }); options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); @@ -1844,6 +1575,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "backend" }); options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); + if (llama_supports_mlock()) { options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); } @@ -1878,12 +1610,13 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --override-kv KEY=TYPE:VALUE", "advanced option to override model metadata by key. may be specified multiple times.\n" "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false" }); - options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (implies --no-mmap)" }); - options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (implies --no-mmap)" }); - options.push_back({ "*", " --lora-base FNAME", "optional model to use as a base for the layers modified by the LoRA adapter" }); - options.push_back({ "*", " --control-vector FNAME", "add a control vector" }); + options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (can be repeated to use multiple adapters)" }); + options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)" }); + options.push_back({ "*", " --control-vector FNAME", "add a control vector\n" + "note: this argument can be repeated to add multiple control vectors" }); options.push_back({ "*", " --control-vector-scaled FNAME SCALE", - "add a control vector with user defined scaling SCALE" }); + "add a control vector with user defined scaling SCALE\n" + "note: this argument can be repeated to add multiple scaled control vectors" }); options.push_back({ "*", " --control-vector-layer-range START END", "layer range to apply the control vector(s) to, start and end inclusive" }); options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n" @@ -1892,6 +1625,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" }); options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); + options.push_back({ "*", "-hft, --hf-token TOKEN", "Hugging Face access token (default: value from HF_TOKEN environment variable)" }); options.push_back({ "retrieval" }); options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); @@ -1917,6 +1651,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "bench", "-ntg n0,n1,...", "number of text generation tokens" }); options.push_back({ "bench", "-npl n0,n1,...", "number of parallel prompts" }); + options.push_back({ "embedding" }); + options.push_back({ "embedding", " --embd-normalize", "normalisation for embendings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize }); + options.push_back({ "embedding", " --embd-output-format", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix" }); + options.push_back({ "embedding", " --embd-separator", "separator of embendings (default \\n) for example \"<#sep#>\"" }); + options.push_back({ "server" }); options.push_back({ "server", " --host HOST", "ip address to listen (default: %s)", params.hostname.c_str() }); options.push_back({ "server", " --port PORT", "port to listen (default: %d)", params.port }); @@ -1939,6 +1678,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); @@ -1953,6 +1694,21 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "logging", " --log-append", "Don't truncate the old log file." }); #endif // LOG_DISABLE_LOGS + options.push_back({ "cvector" }); + options.push_back({ "cvector", "-o, --output FNAME", "output file (default: '%s')", params.cvector_outfile.c_str() }); + options.push_back({ "cvector", " --positive-file FNAME", "positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str() }); + options.push_back({ "cvector", " --negative-file FNAME", "negative prompts file, one prompt per line (default: '%s')", params.cvector_negative_file.c_str() }); + options.push_back({ "cvector", " --pca-batch N", "batch size used for PCA. Larger batch runs faster, but uses more memory (default: %d)", params.n_pca_batch }); + options.push_back({ "cvector", " --pca-iter N", "number of iterations used for PCA (default: %d)", params.n_pca_iterations }); + options.push_back({ "cvector", " --method {pca,mean}", "dimensionality reduction method to be used (default: pca)" }); + + options.push_back({ "export-lora" }); + options.push_back({ "export-lora", "-m, --model", "model path from which to load base model (default '%s')", params.model.c_str() }); + options.push_back({ "export-lora", " --lora FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)" }); + options.push_back({ "export-lora", " --lora-scaled FNAME S", "path to LoRA adapter with user defined scaling S (can be repeated to use multiple adapters)" }); + options.push_back({ "*", "-t, --threads N", "number of threads to use during computation (default: %d)", params.n_threads }); + options.push_back({ "export-lora", "-o, --output FNAME", "output file (default: '%s')", params.lora_outfile.c_str() }); + printf("usage: %s [options]\n", argv[0]); for (const auto & o : options) { @@ -2295,6 +2051,16 @@ std::string fs_get_cache_directory() { return ensure_trailing_slash(cache_directory); } +std::string fs_get_cache_file(const std::string & filename) { + GGML_ASSERT(filename.find(DIRECTORY_SEPARATOR) == std::string::npos); + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + return cache_directory + filename; +} + // // Model utils @@ -2306,9 +2072,9 @@ std::tuple llama_init_from_gpt_par llama_model * model = nullptr; if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); + model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); } else { model = llama_load_model_from_file(params.model.c_str(), mparams); } @@ -2354,18 +2120,26 @@ std::tuple llama_init_from_gpt_par for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? NULL - : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { - fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); + + // try to load as gguf + auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str()); + if (adapter == nullptr) { + fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__); + + // if that fails, try loading as ggla for compatibility + int err = llama_model_apply_lora_from_file(model, + lora_adapter.c_str(), + lora_scale, + nullptr, + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } else { + llama_lora_adapter_set(lctx, adapter, lora_scale); } } @@ -2376,7 +2150,24 @@ std::tuple llama_init_from_gpt_par if (params.warmup) { LOG("warming up the model with an empty run\n"); - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + std::vector tmp; + llama_token bos = llama_token_bos(model); + llama_token eos = llama_token_eos(model); + // some models (e.g. T5) don't have a BOS token + if (bos != -1) { + tmp.push_back(bos); + } + tmp.push_back(eos); + + if (llama_model_has_encoder(model)) { + llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0)); + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == -1) { + decoder_start_token_id = bos; + } + tmp.clear(); + tmp.push_back(decoder_start_token_id); + } llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); llama_synchronize(lctx); @@ -2459,6 +2250,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.pooling_type = params.pooling_type; + cparams.attention_type = params.attention_type; cparams.defrag_thold = params.defrag_thold; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; @@ -2478,7 +2270,7 @@ static bool starts_with(const std::string & str, const std::string & prefix) { return str.rfind(prefix, 0) == 0; } -static bool llama_download_file(const std::string & url, const std::string & path) { +static bool llama_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); @@ -2493,6 +2285,15 @@ static bool llama_download_file(const std::string & url, const std::string & pat curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + // Check if hf-token or bearer-token was specified + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer "; + auth_header += hf_token.c_str(); + struct curl_slist *http_headers = NULL; + http_headers = curl_slist_append(http_headers, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + } + #if defined(_WIN32) // CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of // operating system. Currently implemented under MS-Windows. @@ -2609,7 +2410,14 @@ static bool llama_download_file(const std::string & url, const std::string & pat } // Set the output file - std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb"), fclose); + + struct FILE_deleter { + void operator()(FILE * f) const { + fclose(f); + } + }; + + std::unique_ptr outfile(fopen(path_temporary.c_str(), "wb")); if (!outfile) { fprintf(stderr, "%s: error opening local file for writing: %s\n", __func__, path.c_str()); return false; @@ -2681,6 +2489,7 @@ static bool llama_download_file(const std::string & url, const std::string & pat struct llama_model * llama_load_model_from_url( const char * model_url, const char * path_model, + const char * hf_token, const struct llama_model_params & params) { // Basic validation of the model_url if (!model_url || strlen(model_url) == 0) { @@ -2688,7 +2497,7 @@ struct llama_model * llama_load_model_from_url( return NULL; } - if (!llama_download_file(model_url, path_model)) { + if (!llama_download_file(model_url, path_model, hf_token)) { return NULL; } @@ -2736,14 +2545,14 @@ struct llama_model * llama_load_model_from_url( // Prepare download in parallel std::vector> futures_download; for (int idx = 1; idx < n_split; idx++) { - futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split](int download_idx) -> bool { + futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool { char split_path[PATH_MAX] = {0}; llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split); char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0}; llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split); - return llama_download_file(split_url, split_path); + return llama_download_file(split_url, split_path, hf_token); }, idx)); } @@ -2762,6 +2571,7 @@ struct llama_model * llama_load_model_from_hf( const char * repo, const char * model, const char * path_model, + const char * hf_token, const struct llama_model_params & params) { // construct hugging face model url: // @@ -2777,7 +2587,7 @@ struct llama_model * llama_load_model_from_hf( model_url += "/resolve/main/"; model_url += model; - return llama_load_model_from_url(model_url.c_str(), path_model, params); + return llama_load_model_from_url(model_url.c_str(), path_model, hf_token, params); } #else @@ -2785,6 +2595,7 @@ struct llama_model * llama_load_model_from_hf( struct llama_model * llama_load_model_from_url( const char * /*model_url*/, const char * /*path_model*/, + const char * /*hf_token*/, const struct llama_model_params & /*params*/) { fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); return nullptr; @@ -2794,6 +2605,7 @@ struct llama_model * llama_load_model_from_hf( const char * /*repo*/, const char * /*model*/, const char * /*path_model*/, + const char * /*hf_token*/, const struct llama_model_params & /*params*/) { fprintf(stderr, "%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); return nullptr; @@ -2858,51 +2670,35 @@ std::vector llama_tokenize( } std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); } - return std::string(result.data(), result.size()); + return piece; } -std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) { - const llama_token bos_id = llama_token_bos(llama_get_model(ctx)); - - std::string piece; - std::string result; - - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); - - // remove the leading space of the first non-BOS token - if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') { - piece = piece.substr(1); - } - - result += piece; +std::string llama_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization } - return result; -} - -std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) { - std::string piece; - std::string result; - - for (size_t i = 0; i < tokens.size(); ++i) { - piece = llama_token_to_piece(ctx, tokens[i]); - - result += piece; - } + text.resize(n_chars); // NOTE: the original tokenizer decodes bytes after collecting the pieces. - return result; + return text; } bool llama_should_add_bos_token(const llama_model * model) { @@ -2911,12 +2707,91 @@ bool llama_should_add_bos_token(const llama_model * model) { return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } +// +// Chat template utils +// + bool llama_chat_verify_template(const std::string & tmpl) { llama_chat_message chat[] = {{"user", "test"}}; int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & msgs, + bool add_ass) { + int alloc_size = 0; + bool fallback = false; // indicate if we must fallback to default chatml + std::vector chat; + for (auto & msg : msgs) { + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += (msg.role.size() + msg.content.size()) * 1.25; + } + + const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); + std::vector buf(alloc_size); + + // run the first time to get the total output length + int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + + // error: chat template is not supported + if (res < 0) { + if (ptr_tmpl != nullptr) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } else { + // If the built-in template is not supported, we default to chatml + res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + fallback = true; + } + } + + // if it turns out that our buffer is too small, we resize it + if ((size_t) res > buf.size()) { + buf.resize(res); + res = llama_chat_apply_template( + fallback ? nullptr : model, + fallback ? "chatml" : ptr_tmpl, + chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + } + + std::string formatted_chat(buf.data(), res); + return formatted_chat; +} + +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass) { + std::ostringstream ss; + auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false); + std::vector chat_new(past_msg); + // if the past_msg ends with a newline, we must preserve it in the formatted version + if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { + ss << "\n"; + }; + // format chat with new_msg + chat_new.push_back(new_msg); + auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); + // get the diff part + ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); + return ss.str(); +} + +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl) { + std::vector msgs = { + {"system", "You are a helpful assistant"}, + {"user", "Hello"}, + {"assistant", "Hi there"}, + {"user", "How are you?"}, + }; + return llama_chat_apply_template(model, tmpl, msgs, true); +} + // // KV cache utils // @@ -2996,14 +2871,34 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n) { +void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm) { double sum = 0.0; - for (int i = 0; i < n; i++) { - sum += inp[i] * inp[i]; - } - sum = sqrt(sum); - const float norm = sum > 0.0 ? 1.0f / sum : 0.0f; + switch (embd_norm) { + case -1: // no normalisation + sum = 1.0; + break; + case 0: // max absolute + for (int i = 0; i < n; i++) { + if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + } + sum /= 32760.0; // make an int16 range + break; + case 2: // euclidean + for (int i = 0; i < n; i++) { + sum += inp[i] * inp[i]; + } + sum = std::sqrt(sum); + break; + default: // p-norm (euclidean is p-norm p=2) + for (int i = 0; i < n; i++) { + sum += std::pow(std::abs(inp[i]), embd_norm); + } + sum = std::pow(sum, 1.0 / embd_norm); + break; + } + + const float norm = sum > 0.0 ? 1.0 / sum : 0.0f; for (int i = 0; i < n; i++) { out[i] = inp[i] * norm; @@ -3021,6 +2916,14 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) sum2 += embd2[i] * embd2[i]; } + // Handle the case where one or both vectors are zero vectors + if (sum1 == 0.0 || sum2 == 0.0) { + if (sum1 == 0.0 && sum2 == 0.0) { + return 1.0f; // two zero vectors are similar + } + return 0.0f; + } + return sum / (sqrt(sum1) * sqrt(sum2)); } @@ -3029,125 +2932,87 @@ float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n) // static llama_control_vector_data llama_control_vector_load_one(const llama_control_vector_load_info & load_info) { - int32_t n_tensors; - - size_t n_bytes = 0; - - uint32_t max_direction_layer = 0; - llama_control_vector_data result = { -1, {} }; - // calculate size of ctx needed for tensors, ensure tensors are f32, and find max layer - { - struct ggml_init_params meta_params = { - /* .mem_size = */ ggml_tensor_overhead() * 128 + ggml_graph_overhead(), - /* .mem_buffer = */ nullptr, - /* .no_alloc = */ true, - }; - ggml_context * meta_ctx = ggml_init(meta_params); - struct gguf_init_params meta_gguf_params = { - /* .no_alloc = */ true, - /* .ctx = */ &meta_ctx, - }; - struct gguf_context * meta_ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params); - if (!meta_ctx_gguf) { - fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str()); - ggml_free(meta_ctx); - return result; - } - - n_tensors = gguf_get_n_tensors(meta_ctx_gguf); - for (int i = 0; i < n_tensors; i++) { - std::string name = gguf_get_tensor_name(meta_ctx_gguf, i); - - // split on '.' - size_t dotpos = name.find('.'); - if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { - try { - uint32_t layer = std::stoi(name.substr(dotpos + 1)); - if (layer == 0) { - fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); - ggml_free(meta_ctx); - gguf_free(meta_ctx_gguf); - return result; - } - if (layer > max_direction_layer) { - max_direction_layer = layer; - } - } catch (...) { - fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); - ggml_free(meta_ctx); - gguf_free(meta_ctx_gguf); - return result; - } - } - - struct ggml_tensor * tensor_meta = ggml_get_tensor(meta_ctx, name.c_str()); - if (tensor_meta->type != GGML_TYPE_F32 || ggml_n_dims(tensor_meta) != 1) { - fprintf(stderr, "%s: direction tensor invalid in %s\n", __func__, load_info.fname.c_str()); - ggml_free(meta_ctx); - gguf_free(meta_ctx_gguf); - return result; - } - if (result.n_embd == -1) { - result.n_embd = ggml_nelements(tensor_meta); - } else if (ggml_nelements(tensor_meta) != result.n_embd) { - fprintf(stderr, "%s: direction tensor sizes mismatched in %s\n", __func__, load_info.fname.c_str()); - ggml_free(meta_ctx); - gguf_free(meta_ctx_gguf); - return result; - } - n_bytes += ggml_nbytes(tensor_meta); - } - ggml_free(meta_ctx); - gguf_free(meta_ctx_gguf); + ggml_context * ctx = nullptr; + struct gguf_init_params meta_gguf_params = { + /* .no_alloc = */ false, + /* .ctx = */ &ctx, + }; + struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params); + if (!ctx_gguf) { + fprintf(stderr, "%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str()); + return result; } + int32_t n_tensors = gguf_get_n_tensors(ctx_gguf); if (n_tensors == 0) { fprintf(stderr, "%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str()); - return result; } - // load and scale tensors into final control vector context - struct ggml_init_params ggml_params = { - /* .mem_size = */ ggml_tensor_overhead() * n_tensors + n_bytes, - /* .mem_buffer = */ nullptr, - /* .no_alloc = */ false, - }; - struct ggml_context * ctx = ggml_init(ggml_params); + for (int i = 0; i < n_tensors; i++) { + std::string name = gguf_get_tensor_name(ctx_gguf, i); - struct gguf_init_params params = { - /*.no_alloc = */ false, - /*.ctx = */ &ctx, - }; - struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), params); - if (!ctx_gguf) { - fprintf(stderr, "%s: failed to load control vector from %s\n", __func__, load_info.fname.c_str()); - ggml_free(ctx); - return result; - } + int layer_idx = -1; - // do not store data for layer 0 (it's not used) - result.data.resize(result.n_embd * max_direction_layer); - - for (uint32_t il = 1; il <= max_direction_layer; il++) { - const std::string name = "direction." + std::to_string(il); - const ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); - - float * dst = result.data.data() + result.n_embd * (il - 1); - - if (tensor) { - const float * src = (const float *) tensor->data; - for (int j = 0; j < result.n_embd; j++) { - dst[j] = src[j] * load_info.strength; - } - } else { - for (int j = 0; j < result.n_embd; j++) { - dst[j] = 0.0f; + // split on '.' + size_t dotpos = name.find('.'); + if (dotpos != std::string::npos && name.substr(0, dotpos) == "direction") { + try { + layer_idx = std::stoi(name.substr(dotpos + 1)); + } catch (...) { + layer_idx = -1; } } + if (layer_idx < 0) { + fprintf(stderr, "%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } else if (layer_idx == 0) { + fprintf(stderr, "%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + + struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str()); + if (tensor->type != GGML_TYPE_F32) { + fprintf(stderr, "%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + if (ggml_n_dims(tensor) != 1) { + fprintf(stderr, "%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + + if (result.n_embd == -1) { + result.n_embd = ggml_nelements(tensor); + } else if (ggml_nelements(tensor) != result.n_embd) { + fprintf(stderr, "%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str()); + result.n_embd = -1; + break; + } + + // extend if necessary - do not store data for layer 0 (it's not used) + result.data.resize(std::max(result.data.size(), static_cast(result.n_embd * layer_idx)), 0.0f); + + const float * src = (const float *) tensor->data; + float * dst = result.data.data() + result.n_embd * (layer_idx - 1); // layer 1 at [0] + for (int j = 0; j < result.n_embd; j++) { + dst[j] += src[j] * load_info.strength; // allows multiple directions for same layer in same file + } + } + if (result.n_embd == -1) { + fprintf(stderr, "%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str()); + result.data.clear(); + } + + gguf_free(ctx_gguf); + ggml_free(ctx); + return result; } @@ -3158,16 +3023,19 @@ llama_control_vector_data llama_control_vector_load(const std::vector(la).c_str(), std::get<1>(la)); } - fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); diff --git a/llama/common.h b/llama/common.h index 3aba4bfa..181c412c 100644 --- a/llama/common.h +++ b/llama/common.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -78,6 +78,12 @@ int32_t cpu_get_num_math(); // CLI argument parsing // +// dimensionality reduction methods, used by cvector-generator +enum dimre_method { + DIMRE_METHOD_PCA, + DIMRE_METHOD_MEAN, +}; + struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed @@ -99,7 +105,6 @@ struct gpt_params { int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. int32_t grp_attn_n = 1; // group-attention factor int32_t grp_attn_w = 512; // group-attention width int32_t n_print = -1; // print token count every n tokens (-1 = disabled) @@ -120,6 +125,7 @@ struct gpt_params { enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings // // sampling parameters struct llama_sampling_params sparams; @@ -128,6 +134,7 @@ struct gpt_params { std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias std::string model_url = ""; // model url to download + std::string hf_token = ""; // HF token std::string hf_repo = ""; // HF repo std::string hf_file = ""; // HF file std::string prompt = ""; @@ -147,7 +154,6 @@ struct gpt_params { // TODO: avoid tuple, use struct std::vector> lora_adapter; // lora adapter path with user defined scale - std::string lora_base = ""; // base model path for the lora adapter std::vector control_vectors; // control vector with user defined scale @@ -179,7 +185,6 @@ struct gpt_params { bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it - bool embedding = false; // get only sentence embedding bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\" bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles @@ -206,6 +211,12 @@ struct gpt_params { std::string mmproj = ""; // path to multimodal projector std::vector image; // path to image file(s) + // embedding + bool embedding = false; // get only sentence embedding + int32_t embd_normalize = 2; // normalisation for embendings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix + std::string embd_sep = "\n"; // separator of embendings + // server params int32_t port = 8080; // server listens on this network port int32_t timeout_read = 600; // http read timeout in seconds @@ -216,6 +227,7 @@ struct gpt_params { std::string public_path = ""; std::string chat_template = ""; std::string system_prompt = ""; + bool enable_chat_template = true; std::vector api_keys; @@ -229,6 +241,8 @@ struct gpt_params { std::string slot_save_path; + float slot_prompt_similarity = 0.5f; + // batched-bench params bool is_pp_shared = false; @@ -256,8 +270,21 @@ struct gpt_params { bool process_output = false; // collect data for the output tensor bool compute_ppl = true; // whether to compute perplexity + + // cvector-generator params + int n_pca_batch = 100; + int n_pca_iterations = 1000; + dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; + std::string cvector_outfile = "control_vector.gguf"; + std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; + std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; + + bool spm_infill = false; // suffix/prefix/middle pattern for infill + + std::string lora_outfile = "ggml-lora-merged-f16.gguf"; }; +void gpt_params_handle_hf_token(gpt_params & params); void gpt_params_handle_model_default(gpt_params & params); bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); @@ -301,6 +328,7 @@ bool fs_validate_filename(const std::string & filename); bool fs_create_directory_with_parents(const std::string & path); std::string fs_get_cache_directory(); +std::string fs_get_cache_file(const std::string & filename); // // Model utils @@ -312,8 +340,8 @@ std::tuple llama_init_from_gpt_par struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); -struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const struct llama_model_params & params); -struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const struct llama_model_params & params); +struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); +struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); // Batch utils @@ -351,21 +379,13 @@ std::string llama_token_to_piece( llama_token token, bool special = true); -// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function -// that takes into account the tokenizer type and decides how to handle the leading space -// // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` -// removes the leading space from the first non-BOS token -std::string llama_detokenize_spm( +// optionally renders special/control tokens +std::string llama_detokenize( llama_context * ctx, - const std::vector & tokens); - -// detokenizes a vector of tokens into a string -// should work similar to Python's `tokenizer.decode` -std::string llama_detokenize_bpe( - llama_context * ctx, - const std::vector & tokens); + const std::vector & tokens, + bool special = true); // Uses the value from the model metadata if possible, otherwise // defaults to true when model type is SPM, otherwise false. @@ -375,9 +395,34 @@ bool llama_should_add_bos_token(const llama_model * model); // Chat template utils // +// same with llama_chat_message, but uses std::string +struct llama_chat_msg { + std::string role; + std::string content; +}; + // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid bool llama_chat_verify_template(const std::string & tmpl); +// CPP wrapper for llama_chat_apply_template +// If the built-in template is not supported, we default to chatml +// If the custom "tmpl" is not supported, we throw an error +std::string llama_chat_apply_template(const struct llama_model * model, + const std::string & tmpl, + const std::vector & chat, + bool add_ass); + +// Format single message, while taking into account the position of that message in chat history +std::string llama_chat_format_single(const struct llama_model * model, + const std::string & tmpl, + const std::vector & past_msg, + const llama_chat_msg & new_msg, + bool add_ass); + +// Returns an example of formatted chat +std::string llama_chat_format_example(const struct llama_model * model, + const std::string & tmpl); + // // KV cache utils // @@ -392,7 +437,7 @@ void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_siz // Embedding utils // -void llama_embd_normalize(const float * inp, float * out, int n); +void llama_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2); float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); @@ -436,4 +481,3 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha void yaml_dump_non_result_info( FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); - diff --git a/llama/ggml-aarch64.c b/llama/ggml-aarch64.c new file mode 100644 index 00000000..c2189c02 --- /dev/null +++ b/llama/ggml-aarch64.c @@ -0,0 +1,2219 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#include "ggml-aarch64.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#endif + +#define UNUSED GGML_UNUSED + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave +// +// - in : an array of block_q4_0 pointers +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations durin unpacking) +// +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 2; i++) { + int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < QK4_0 * 4; i++) { + int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave; + int src_id = (i % (8 * blck_size_interleave)) / blck_size_interleave; + src_offset += (i % blck_size_interleave); + + out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask; + } + + return out; +} + +void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * restrict y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) { + assert(nrow == 4); + UNUSED(nrow); + if (blck_size_interleave == 4) { + quantize_q8_0_4x4(x, vy, n_per_row); + } else if (blck_size_interleave == 8) { + quantize_q8_0_4x8(x, vy, n_per_row); + } else { + assert(false); + } +} + +static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { + assert(n_per_row % QK4_0 == 0); + const int nb = n_per_row / QK4_0; + + void * out_ptr = NULL; + if (nrows_interleaved == 8) { + out_ptr = (block_q4_0x8 *) dst; + } + else if (nrows_interleaved == 4) { + out_ptr = (block_q4_0x4 *) dst; + } + assert(nrows_interleaved <= 8); + block_q4_0 dst_tmp[8]; + + for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { + + for (int64_t x = 0; x < nb; x++) { + + for (int i = 0; i < nrows_interleaved; i++ ) { + quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); + } + + if (nrows_interleaved == 8) { + *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88); + out_ptr = (block_q4_0x8 *) out_ptr + 1; + } + else if (nrows_interleaved == 4) { + *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88); + out_ptr = (block_q4_0x4 *) out_ptr + 1; + } + } + } + + return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); +} + +size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_4x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); + } + else { + assert(false); + return 0; + } +} + +size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); + } + else { + assert(false); + return 0; + } +} + +void ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v31.16b, #0x4\n" + "movi v30.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "movi v29.16b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ldr q28, [%x[b_ptr], #0x0]\n" + "ldr q27, [x22, #0x0]\n" + "movi v26.4s, #0x0\n" + "sub x20, x22, #0x2\n" + "ldr q25, [x22, #0x10]\n" + "ldr q24, [%x[b_ptr], #0x10]\n" + "sub x21, x21, #0x1\n" + "add x22, x22, #0x22\n" + "ldr q23, [%x[b_ptr], #0x20]\n" + "ldr q22, [%x[b_ptr], #0x30]\n" + "ld1r { v21.8h }, [x20]\n" + "ldr q20, [%x[b_ptr], #-0x8]\n" + "sshl v16.16b, v28.16b, v31.16b\n" + "and v28.16b, v28.16b, v30.16b\n" + "sshl v19.16b, v24.16b, v31.16b\n" + "and v24.16b, v24.16b, v30.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "sshl v18.16b, v23.16b, v31.16b\n" + "and v23.16b, v23.16b, v30.16b\n" + ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" + "sshl v17.16b, v22.16b, v31.16b\n" + "and v22.16b, v22.16b, v30.16b\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v16.4s, v20.4h\n" + ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" + "fmul v16.4s, v16.4s, v21.4s\n" + ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" + ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" + ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" + ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" + ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" + ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v29.4s, v26.4s, v16.4s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q29, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" + ); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "movi v2.16b, #0x4\n" + "movi v1.16b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x8\n" + "1:" // Column loop + "add x23, %x[a_ptr], #0x2\n" + "movi v0.16b, #0x0\n" + "mov x22, %x[nb]\n" + "2:" // Block loop + "ldr q31, [%x[b_ptr], #0x0]\n" + "ldr q30, [%x[b_ptr], #0x10]\n" + "mov x21, x23\n" + "movi v29.4s, #0x0\n" + "ldr q28, [%x[b_ptr], #0x20]\n" + "ldr q27, [%x[b_ptr], #0x30]\n" + "movi v26.4s, #0x0\n" + "sub x20, x23, #0x2\n" + "ld1r { v25.8h }, [x20]\n" + "ldr q24, [%x[b_ptr], #-0x8]\n" + "sub x22, x22, #0x1\n" + "add x23, x23, #0x22\n" + "ld1r { v23.2d }, [x21], #0x8\n" + "sshl v22.16b, v31.16b, v2.16b\n" + "sshl v16.16b, v30.16b, v2.16b\n" + "add %x[b_ptr], %x[b_ptr], #0x48\n" + "ld1r { v21.2d }, [x21], #0x8\n" + "sshl v20.16b, v28.16b, v2.16b\n" + "sshl v19.16b, v27.16b, v2.16b\n" + "ld1r { v18.2d }, [x21], #0x8\n" + "ld1r { v17.2d }, [x21], #0x8\n" + "and v31.16b, v31.16b, v1.16b\n" + "and v30.16b, v30.16b, v1.16b\n" + ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" + ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" + "and v28.16b, v28.16b, v1.16b\n" + "and v27.16b, v27.16b, v1.16b\n" + "fcvtl v25.4s, v25.4h\n" + "fcvtl v16.4s, v24.4h\n" + ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" + ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" + "fmul v16.4s, v16.4s, v25.4s\n" + ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" + ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" + ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" + ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" + "addp v29.4s, v29.4s, v26.4s\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v0.4s, v29.4s, v16.4s\n" + "cbnz x22, 2b\n" + "sub %x[nc], %x[nc], #0x4\n" + "str q0, [%x[res_ptr], #0x0]\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (ggml_cpu_has_neon()) { + GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(ggml_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +#endif +} + +void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(!(ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) && + "__ARM_NEON and __ARM_FEATURE_MATMUL_INT8 defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntw() == 8) { + GGML_ASSERT(!(ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE defined, use the Q4_0_8_8 quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} + +void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) && ! ((defined(_MSC_VER)) && ! defined(__clang__)) + if (svcntw() == 8) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, #0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } + else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + GGML_ASSERT((ggml_cpu_has_sve() && (svcntw() == 8)) && + "__ARM_FEATURE_SVE for vector size of 256-bits not defined, use the Q4_0_4_8 quantization format for optimal " + "performance"); + } + else if (ggml_cpu_has_neon()) { + GGML_ASSERT(((ggml_cpu_has_sve() && (svcntw() == 8)) || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE for vector size of 256-bits and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 " + "quantization format for optimal performance"); + } +#endif +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + GGML_ASSERT(ggml_cpu_has_sve() && + "__ARM_FEATURE_SVE not defined, use the Q4_0_4_8 quantization format for optimal performance"); +#elif defined(__ARM_NEON) && defined(__aarch64__) + GGML_ASSERT((ggml_cpu_has_sve() || ggml_cpu_has_matmul_int8()) && + "__ARM_FEATURE_SVE and __ARM_FEATURE_MATMUL_INT8 not defined, use the Q4_0_4_4 quantization format for optimal " + "performance"); +#else + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +#endif +} diff --git a/llama/ggml-aarch64.h b/llama/ggml-aarch64.h new file mode 100644 index 00000000..f00fde74 --- /dev/null +++ b/llama/ggml-aarch64.h @@ -0,0 +1,65 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd. +#pragma once + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "ggml.h" + +// GGML internal header + +#ifdef __cplusplus +extern "C" { +#endif + +// Quantization +void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + +void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); + +// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") +size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + +// GEMV +void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +// GEMM +void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +#ifdef __cplusplus +} +#endif + diff --git a/llama/ggml-alloc.c b/llama/ggml-alloc.c index 268f2234..ca84d2e9 100644 --- a/llama/ggml-alloc.c +++ b/llama/ggml-alloc.c @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -117,8 +117,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) { fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n", __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset); - GGML_ASSERT(!"not enough space in the buffer"); - return; + GGML_ABORT("not enough space in the buffer"); } void * addr = (char *)ggml_backend_buffer_get_base(talloc->buffer) + talloc->offset; @@ -159,7 +158,7 @@ static void add_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, return; } } - GGML_ASSERT(!"out of allocated_tensors"); + GGML_ABORT("out of allocated_tensors"); } static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offset, const struct ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { @@ -168,8 +167,7 @@ static void remove_allocated_tensor(struct ggml_dyn_tallocr * alloc, size_t offs return; } } - fprintf(stderr, "tried to free tensor %s not found\n", tensor->name); - GGML_ASSERT(!"tensor not found"); + GGML_ABORT("tried to free tensor %s not found\n", tensor->name); } #endif @@ -202,8 +200,7 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz // this should never happen fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n", __func__, size, max_avail); - GGML_ASSERT(!"not enough space in the buffer"); - GGML_UNREACHABLE(); + GGML_ABORT("not enough space in the buffer"); } } @@ -365,6 +362,7 @@ struct hash_node { }; struct tensor_alloc { + int buffer_id; size_t offset; size_t size_max; // 0 = pre-allocated, unused, or view }; @@ -375,7 +373,6 @@ struct leaf_alloc { }; struct node_alloc { - int buffer_id; struct tensor_alloc dst; struct tensor_alloc src[GGML_MAX_SRC]; }; @@ -412,8 +409,19 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs for (int i = 0; i < n_bufs; i++) { galloc->bufts[i] = bufts[i]; galloc->buffers[i] = NULL; - size_t alignment = ggml_backend_buft_get_alignment(bufts[i]); - galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment); + + // check if the same buffer type is used multiple times and reuse the same allocator + for (int j = 0; j < i; j++) { + if (bufts[i] == bufts[j]) { + galloc->buf_tallocs[i] = galloc->buf_tallocs[j]; + break; + } + } + + if (galloc->buf_tallocs[i] == NULL) { + size_t alignment = ggml_backend_buft_get_alignment(bufts[i]); + galloc->buf_tallocs[i] = ggml_dyn_tallocr_new(alignment); + } } galloc->n_buffers = n_bufs; @@ -431,14 +439,34 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { for (int i = 0; i < galloc->n_buffers; i++) { if (galloc->buffers != NULL) { - ggml_backend_buffer_free(galloc->buffers[i]); + // skip if already freed + bool freed = false; + for (int j = 0; j < i; j++) { + if (galloc->buffers[j] == galloc->buffers[i]) { + freed = true; + break; + } + } + if (!freed) { + ggml_backend_buffer_free(galloc->buffers[i]); + } } if (galloc->buf_tallocs != NULL) { - ggml_dyn_tallocr_free(galloc->buf_tallocs[i]); + // skip if already freed + bool freed = false; + for (int j = 0; j < i; j++) { + if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) { + freed = true; + break; + } + } + if (!freed) { + ggml_dyn_tallocr_free(galloc->buf_tallocs[i]); + } } } - free(galloc->hash_set.keys); + ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); free(galloc->buffers); @@ -451,7 +479,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { typedef struct ggml_gallocr * ggml_gallocr_t; static struct hash_node * ggml_gallocr_hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) { - size_t i = ggml_hash_find_or_insert(galloc->hash_set, t); + size_t i = ggml_hash_find_or_insert(&galloc->hash_set, t); return &galloc->hash_values[i]; } @@ -537,17 +565,18 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor } } -static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { +static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) { // graph outputs are never freed if (node->flags & GGML_TENSOR_FLAG_OUTPUT) { AT_PRINTF("not freeing output %s\n", node->name); return; } - struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; - ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); size_t offset = hn->offset; + int buffer_id = hn->buffer_id; + struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; + ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; size_t size = ggml_backend_buft_get_alloc_size(buft, node); ggml_dyn_tallocr_free_tensor(alloc, offset, size, node); hn->allocated = false; @@ -559,8 +588,8 @@ static int get_node_buffer_id(const int * node_buffer_ids, int i) { static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) { // clear hash tables - memset(galloc->hash_set.keys, 0, galloc->hash_set.size * sizeof(struct ggml_tensor *)); - memset(galloc->hash_values, 0, galloc->hash_set.size * sizeof(struct hash_node)); + ggml_hash_set_reset(&galloc->hash_set); + memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size); // allocate leafs // these may be tensors that the application is not using in the graph, but may still want to allocate for other purposes @@ -652,11 +681,11 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src_hn->allocated) { - ggml_gallocr_free_node(galloc, view_src, buffer_id); + ggml_gallocr_free_node(galloc, view_src); } } else if (p_hn->allocated) { - ggml_gallocr_free_node(galloc, parent, buffer_id); + ggml_gallocr_free_node(galloc, parent); } } AT_PRINTF("\n"); @@ -665,21 +694,19 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr } bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, const int * node_buffer_ids, const int * leaf_buffer_ids) { - size_t hash_size = graph->visited_hash_table.size; + size_t min_hash_size = graph->n_nodes + graph->n_leafs; + // add 25% margin to avoid hash collisions + min_hash_size += min_hash_size / 4; // initialize hash table - if (galloc->hash_set.size < hash_size) { - free(galloc->hash_set.keys); - free(galloc->hash_values); - galloc->hash_set.size = hash_size; - galloc->hash_set.keys = calloc(hash_size, sizeof(struct ggml_tensor *)); - galloc->hash_values = calloc(hash_size, sizeof(struct hash_node)); + if (galloc->hash_set.size < min_hash_size) { + ggml_hash_set_free(&galloc->hash_set); + galloc->hash_set = ggml_hash_set_new(min_hash_size); GGML_ASSERT(galloc->hash_set.keys != NULL); + + free(galloc->hash_values); + galloc->hash_values = malloc(sizeof(struct hash_node) * galloc->hash_set.size); GGML_ASSERT(galloc->hash_values != NULL); - } else { - // reset hash table - memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * galloc->hash_set.size); - memset(galloc->hash_values, 0, sizeof(struct hash_node) * galloc->hash_set.size); } // reset allocators @@ -700,22 +727,25 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; struct node_alloc * node_alloc = &galloc->node_allocs[i]; - node_alloc->buffer_id = get_node_buffer_id(node_buffer_ids, i); if (node->view_src || node->data) { + node_alloc->dst.buffer_id = -1; node_alloc->dst.offset = SIZE_MAX; node_alloc->dst.size_max = 0; } else { struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - node_alloc->dst.offset = hn->offset; - node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); + node_alloc->dst.buffer_id = hn->buffer_id; + node_alloc->dst.offset = hn->offset; + node_alloc->dst.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node); } for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (!src || src->view_src || src->data) { + node_alloc->src[j].buffer_id = -1; node_alloc->src[j].offset = SIZE_MAX; node_alloc->src[j].size_max = 0; } else { struct hash_node * hn = ggml_gallocr_hash_get(galloc, src); + node_alloc->src[j].buffer_id = hn->buffer_id; node_alloc->src[j].offset = hn->offset; node_alloc->src[j].size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], src); } @@ -732,9 +762,11 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf); galloc->leaf_allocs[i].buffer_id = hn->buffer_id; if (leaf->view_src || leaf->data) { + galloc->leaf_allocs[i].leaf.buffer_id = -1; galloc->leaf_allocs[i].leaf.offset = SIZE_MAX; galloc->leaf_allocs[i].leaf.size_max = 0; } else { + galloc->leaf_allocs[i].leaf.buffer_id = hn->buffer_id; galloc->leaf_allocs[i].leaf.offset = hn->offset; galloc->leaf_allocs[i].leaf.size_max = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], leaf); } @@ -742,6 +774,14 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { + // if the buffer type is used multiple times, we reuse the same buffer + for (int j = 0; j < i; j++) { + if (galloc->buf_tallocs[j] == galloc->buf_tallocs[i]) { + galloc->buffers[i] = galloc->buffers[j]; + break; + } + } + size_t cur_size = galloc->buffers[i] ? ggml_backend_buffer_get_size(galloc->buffers[i]) : 0; size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); @@ -750,12 +790,14 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c #ifndef NDEBUG fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif + ggml_backend_buffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size); if (galloc->buffers[i] == NULL) { fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size); return false; } + ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); } } @@ -766,7 +808,8 @@ bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { return ggml_gallocr_reserve_n(galloc, graph, NULL, NULL); } -static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, int buffer_id, struct tensor_alloc * tensor_alloc) { +static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * tensor, struct tensor_alloc * tensor_alloc) { + int buffer_id = tensor_alloc->buffer_id; assert(tensor->data || tensor->view_src || ggml_backend_buffer_get_alloc_size(galloc->buffers[buffer_id], tensor) <= tensor_alloc->size_max); if (tensor->view_src != NULL) { @@ -794,9 +837,8 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * } } -static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct node_alloc * nalloc, struct tensor_alloc * talloc) { - ggml_backend_buffer_type_t buft = galloc->bufts[nalloc->buffer_id]; - size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(buft, node); +static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) { + size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); return talloc->size_max >= node_size; } @@ -819,7 +861,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph struct ggml_tensor * node = graph->nodes[i]; struct node_alloc * node_alloc = &galloc->node_allocs[i]; - if (!ggml_gallocr_node_needs_realloc(galloc, node, node_alloc, &node_alloc->dst)) { + if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) { #ifndef NDEBUG fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name); #endif @@ -831,7 +873,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph if (src == NULL) { continue; } - if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) { + if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) { #ifndef NDEBUG fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name); #endif @@ -872,7 +914,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; struct leaf_alloc * leaf_alloc = &galloc->leaf_allocs[i]; - ggml_gallocr_init_tensor(galloc, leaf, leaf_alloc->buffer_id, &leaf_alloc->leaf); + ggml_gallocr_init_tensor(galloc, leaf, &leaf_alloc->leaf); } // nodes for (int i = 0; i < graph->n_nodes; i++) { @@ -883,9 +925,9 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph) if (src == NULL) { continue; } - ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]); + ggml_gallocr_init_tensor(galloc, src, &node_alloc->src[j]); } - ggml_gallocr_init_tensor(galloc, node, node_alloc->buffer_id, &node_alloc->dst); + ggml_gallocr_init_tensor(galloc, node, &node_alloc->dst); } return true; @@ -897,6 +939,15 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { if (galloc->buffers[buffer_id] == NULL) { return 0; } + + for (int i = 0; i < buffer_id; i++) { + if (galloc->buffers[i] == galloc->buffers[buffer_id]) { + // this buffer is the same as a previous one due to the same buffer type being used multiple times + // only return the buffer size the first time it appears to avoid double counting + return 0; + } + } + return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]); } @@ -912,7 +963,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx, fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); #endif for (size_t i = 0; i < *n_buffers; i++) { - ggml_backend_buffer_free(*buffers[i]); + ggml_backend_buffer_free((*buffers)[i]); } free(*buffers); return false; diff --git a/llama/ggml-alloc.h b/llama/ggml-alloc.h index 5311cc17..676c9695 100644 --- a/llama/ggml-alloc.h +++ b/llama/ggml-alloc.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-backend-impl.h b/llama/ggml-backend-impl.h index c9f4b7a6..c44e5b0f 100644 --- a/llama/ggml-backend-impl.h +++ b/llama/ggml-backend-impl.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -43,13 +43,15 @@ extern "C" { struct ggml_backend_buffer_type_i { const char * (*GGML_CALL get_name) (ggml_backend_buffer_type_t buft); + // allocate a buffer of this type ggml_backend_buffer_t (*GGML_CALL alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size); - size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment - size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft); // allocation max size - size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding - bool (*GGML_CALL supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend + // tensor alignment + size_t (*GGML_CALL get_alignment) (ggml_backend_buffer_type_t buft); + // max buffer size that can be allocated + size_t (*GGML_CALL get_max_size) (ggml_backend_buffer_type_t buft); + // data size needed to allocate the tensor, including padding + size_t (*GGML_CALL get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // check if tensor data is in host memory - // should be equivalent to supports_backend(buft, ggml_backend_cpu_init()) bool (*GGML_CALL is_host) (ggml_backend_buffer_type_t buft); }; @@ -118,27 +120,37 @@ extern "C" { void (*GGML_CALL synchronize)(ggml_backend_t backend); // compute graph with a plan (not used currently) + // create a new plan for a graph ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph); void (*GGML_CALL graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan); + // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology + void (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph); + // compute the graph with the plan + enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); - // compute graph with a plan - enum ggml_status (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan); // compute graph without a plan (async) enum ggml_status (*GGML_CALL graph_compute) (ggml_backend_t backend, struct ggml_cgraph * cgraph); - // check if the backend supports an operation + // check if the backend can compute an operation bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op); + // check if the backend can use tensors allocated in a buffer type + bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft); + // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer // these should be expensive operations with large batch sizes that may benefit from running on this backend // even if the weight has to be copied from the CPU temporarily bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op); // (optional) event synchronization + // create a new event that can record events on this backend instance ggml_backend_event_t (*GGML_CALL event_new) (ggml_backend_t backend); void (*GGML_CALL event_free) (ggml_backend_event_t event); + // record an event on the backend instance that created it void (*GGML_CALL event_record) (ggml_backend_event_t event); + // wait for an event on on a different backend instance void (*GGML_CALL event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + // block until an event is recorded void (*GGML_CALL event_synchronize) (ggml_backend_event_t event); }; diff --git a/llama/ggml-backend.c b/llama/ggml-backend.c index 0275b136..ca846cdb 100644 --- a/llama/ggml-backend.c +++ b/llama/ggml-backend.c @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -70,10 +70,6 @@ GGML_CALL size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buf return ggml_nbytes(tensor); } -bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - return buft->iface.supports_backend(buft, backend); -} - bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { if (buft->iface.is_host) { return buft->iface.is_host(buft); @@ -169,6 +165,10 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe } } +enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) { + return buffer->usage; +} + ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { return buffer->buft; } @@ -317,6 +317,10 @@ bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * return backend->iface.supports_op(backend, op); } +bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + return backend->iface.supports_buft(backend, buft); +} + bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { if (backend->iface.offload_op != NULL) { return backend->iface.offload_op(backend, op); @@ -425,7 +429,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) // backend registry -#define GGML_REG_MAX_BACKENDS 16 +#define GGML_REG_MAX_BACKENDS 64 struct ggml_backend_reg { char name[128]; @@ -476,6 +480,11 @@ GGML_CALL static void ggml_backend_registry_init(void) { extern GGML_CALL void ggml_backend_kompute_reg_devices(void); ggml_backend_kompute_reg_devices(); #endif + +#ifdef GGML_USE_CANN + extern GGML_CALL int ggml_backend_cann_reg_devices(void); + ggml_backend_cann_reg_devices(); +#endif } GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { @@ -670,12 +679,6 @@ GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_ GGML_UNUSED(buft); } -GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - return ggml_backend_is_cpu(backend); - - GGML_UNUSED(buft); -} - GGML_CALL static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return true; @@ -690,7 +693,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) { /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, }, /* .context = */ NULL, @@ -746,7 +748,6 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) { /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .supports_backend = */ ggml_backend_cpu_buffer_type_supports_backend, /* .is_host = */ ggml_backend_cpu_buffer_type_is_host, }, /* .context = */ NULL, @@ -867,6 +868,12 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const GGML_UNUSED(backend); } +GGML_CALL static bool ggml_backend_cpu_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_host(buft); + + GGML_UNUSED(backend); +} + static struct ggml_backend_i cpu_backend_i = { /* .get_name = */ ggml_backend_cpu_name, /* .free = */ ggml_backend_cpu_free, @@ -877,9 +884,11 @@ static struct ggml_backend_i cpu_backend_i = { /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free, + /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute, /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .supports_op = */ ggml_backend_cpu_supports_op, + /* .supports_buft = */ ggml_backend_cpu_supports_buft, /* .offload_op = */ NULL, /* .event_new = */ NULL, /* .event_free = */ NULL, @@ -1077,17 +1086,19 @@ struct ggml_backend_sched { ggml_backend_buffer_type_t bufts[GGML_SCHED_MAX_BACKENDS]; ggml_gallocr_t galloc; - // hash keys of the nodes in the graph - struct ggml_hash_set hash_set; - // hash values - int * tensor_backend_id; - struct ggml_tensor * (* tensor_copies)[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; + // hash map of the nodes in the graph + struct ggml_hash_set hash_set; + int * hv_tensor_backend_ids; // [hash_set.size] + struct ggml_tensor ** hv_tensor_copies; // [hash_set.size][n_backends][n_copies] int * node_backend_ids; // [graph_size] int * leaf_backend_ids; // [graph_size] + int * prev_node_backend_ids; // [graph_size] + int * prev_leaf_backend_ids; // [graph_size] + // copy of the graph with modified inputs - struct ggml_cgraph * graph; + struct ggml_cgraph graph; // graph splits struct ggml_backend_sched_split * splits; @@ -1106,17 +1117,16 @@ struct ggml_backend_sched { ggml_backend_sched_eval_callback callback_eval; void * callback_eval_user_data; - // align context_buffer to GGML_MEM_ALIGN -#ifdef _MSC_VER - __declspec(align(GGML_MEM_ALIGN)) -#else - __attribute__((aligned(GGML_MEM_ALIGN))) -#endif - char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; + char * context_buffer; + size_t context_buffer_size; + + bool debug; }; -#define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor) -#define tensor_backend_id(tensor) sched->tensor_backend_id[hash_id(tensor)] +#define hash_id(tensor) ggml_hash_find_or_insert(&sched->hash_set, tensor) +#define tensor_backend_id(tensor) sched->hv_tensor_backend_ids[hash_id(tensor)] +#define tensor_id_copy(id, backend_id, copy_id) sched->hv_tensor_copies[(id) * sched->n_backends * sched->n_copies + (backend_id) * sched->n_copies + (copy_id)] +#define tensor_copy(tensor, backend_id, copy_id) tensor_id_copy(hash_id(tensor), backend_id, copy_id) // returns the priority of the backend, lower id is higher priority static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backend_t backend) { @@ -1128,22 +1138,24 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen return -1; } -static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) { +static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) { ggml_backend_buffer_t buffer = tensor->buffer; if (buffer == NULL) { return -1; } - // find highest prio backend that supports the buffer type + // find highest prio backend that supports the buffer type and the op for (int i = 0; i < sched->n_backends; i++) { - if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) { + if (ggml_backend_supports_buft(sched->backends[i], buffer->buft) && + ggml_backend_supports_op(sched->backends[i], op)) { return i; } } - fprintf(stderr, "%s: error: no backend supports buffer type %s used in tensor %s\n", - __func__, ggml_backend_buffer_name(buffer), tensor->name); - GGML_ASSERT(false); +#ifndef NDEBUG + fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n", + __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name); +#endif return -1; } @@ -1162,7 +1174,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st // TODO: use supports_op to check if the backend supports the op // assign pre-allocated nodes to their backend - int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor); + int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); if (cur_backend_id != -1) { SET_CAUSE(tensor, "1.dst"); return cur_backend_id; @@ -1170,7 +1182,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st // view_src if (tensor->view_src != NULL) { - cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src); + cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor); if (cur_backend_id != -1) { SET_CAUSE(tensor, "1.vsrc"); return cur_backend_id; @@ -1184,7 +1196,6 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st return cur_backend_id; } - // assign nodes that use weights to the backend of the weights // operations with weights are preferably run on the same backend as the weights for (int i = 0; i < GGML_MAX_SRC; i++) { const struct ggml_tensor * src = tensor->src[i]; @@ -1192,11 +1203,11 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st continue; } if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { - int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src); + int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op if (src_backend_id == sched->n_backends - 1) { for (int b = 0; b < src_backend_id; b++) { - if (ggml_backend_offload_op(sched->backends[b], tensor)) { + if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); return b; } @@ -1254,10 +1265,33 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } } -//#define DEBUG_PASS1 -//#define DEBUG_PASS2 -//#define DEBUG_PASS3 -//#define DEBUG_PASS4 +static bool ggml_backend_sched_buffer_supported(ggml_backend_sched_t sched, struct ggml_tensor * t, int backend_id) { + ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer; + ggml_backend_buffer_type_t buft = NULL; + + if (buf) { + // the tensor is already allocated + buft = buf->buft; + } else { + // see if the tensor already has a backend assigned, and use the buffer type of that backend + int tensor_backend_id = tensor_backend_id(t); + if (tensor_backend_id == -1 && t->view_src) { + tensor_backend_id = tensor_backend_id(t->view_src); + } + if (tensor_backend_id != -1) { + buft = sched->bufts[tensor_backend_id]; + } + } + + return buft != NULL && ggml_backend_supports_buft(sched->backends[backend_id], buft); +} + +static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) { + if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) { + *node_backend_id = cur_backend_id; + SET_CAUSE(node, "2.sup"); + } +} // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { @@ -1267,7 +1301,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->is_reset = false; struct ggml_init_params params = { - /* .mem_size = */ sizeof(sched->context_buffer), + /* .mem_size = */ sched->context_buffer_size, /* .mem_buffer = */ sched->context_buffer, /* .no_alloc = */ true }; @@ -1276,52 +1310,52 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->ctx = ggml_init(params); if (sched->ctx == NULL) { - fprintf(stderr, "%s: failed to initialize context\n", __func__); - GGML_ASSERT(false); + GGML_ABORT("%s: failed to initialize context\n", __func__); } // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; int * leaf_backend_id = &tensor_backend_id(leaf); - if (*leaf_backend_id != -1) { - // do not overwrite user assignments - continue; + // do not overwrite user assignments + if (*leaf_backend_id == -1) { + *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf); } - *leaf_backend_id = ggml_backend_sched_backend_id_from_cur(sched, leaf); } for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; int * node_backend_id = &tensor_backend_id(node); - if (*node_backend_id != -1) { - // do not overwrite user assignments - continue; - } - *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); - // src - for (int j = 0; j < GGML_MAX_SRC; j++) { - struct ggml_tensor * src = node->src[j]; - if (src == NULL) { + // do not overwrite user assignments + if (*node_backend_id == -1) { + *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); + +#if 0 + // src + if (node->op == GGML_OP_NONE) { continue; } - int * src_backend_id = &tensor_backend_id(src); - if (*src_backend_id == -1) { - *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src); + + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + int * src_backend_id = &tensor_backend_id(src); + if (*src_backend_id == -1) { + *src_backend_id = ggml_backend_sched_backend_id_from_cur(sched, src); + } } +#endif } } -#ifdef DEBUG_PASS1 - fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); -#endif // pass 2: expand current backend assignments // assign the same backend to adjacent nodes // expand gpu backends (i.e. non last prio) up and down, ignoring cpu (the lowest priority backend) // thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops - - - // pass 2.2 expand gpu down + // ops unsupported by the backend being expanded will be left unassigned so that they can be assigned later when the locations of its inputs are known + // expand gpu down { int cur_backend_id = -1; for (int i = 0; i < graph->n_nodes; i++) { @@ -1337,13 +1371,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg } else { cur_backend_id = *node_backend_id; } - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.2"); + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } - // pass 2.1 expand gpu up + // expand gpu up { int cur_backend_id = -1; for (int i = graph->n_nodes - 1; i >= 0; i--) { @@ -1359,13 +1392,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg } else { cur_backend_id = *node_backend_id; } - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.1"); + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } - // pass 2.4 expand rest down + // expand rest down { int cur_backend_id = -1; for (int i = 0; i < graph->n_nodes; i++) { @@ -1376,13 +1408,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg int * node_backend_id = &tensor_backend_id(node); if (*node_backend_id != -1) { cur_backend_id = *node_backend_id; - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.4"); + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } - // pass 2.3 expand rest up + // expand rest up { int cur_backend_id = -1; for (int i = graph->n_nodes - 1; i >= 0; i--) { @@ -1393,24 +1424,80 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg int * node_backend_id = &tensor_backend_id(node); if (*node_backend_id != -1) { cur_backend_id = *node_backend_id; - } else { - *node_backend_id = cur_backend_id; - SET_CAUSE(node, "2.3"); + } else if (cur_backend_id != -1) { + ggml_backend_sched_set_if_supported(sched, node, cur_backend_id, node_backend_id); } } } -#ifdef DEBUG_PASS2 - fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); -#endif + // pass 3: upgrade nodes to higher prio backends with compatible buffer types + // if the tensor is already in the same buffer type (*) as another higher priority backend, we should move it there + // however, we also need to verify that the sources are in compatible buffer types + // (*) the actual requirement is more relaxed, the buffer type of the backend should be supported by all the users of this tensor further down the graph + // however, this is slow to verify, so we have a more strict requirement that the buffer type is the same + // this is not uncommon since multiple backends can use host memory, with the same buffer type (eg. BLAS and CPU) + // additionally, set remaining unassigned nodes to the backend with the most supported inputs + // only nodes that could not be assigned during expansion due to the backend not supporting the op should be unassigned at this point + for (int i = 0; i < graph->n_nodes; i++) { + struct ggml_tensor * node = graph->nodes[i]; + if (ggml_is_view_op(node->op)) { + continue; + } + int * node_backend_id = &tensor_backend_id(node); + if (*node_backend_id == -1) { + // unassigned node: find the backend with the most supported inputs + int n_supported_best = -1; + for (int b = 0; b < sched->n_backends; b++) { + if (ggml_backend_supports_op(sched->backends[b], node)) { + int n_supported = 0; + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if ((tensor_backend_id(src) != -1 || tensor_backend_id(src->view_src) != -1) && ggml_backend_sched_buffer_supported(sched, src, b)) { + n_supported++; + } + } + if (n_supported > n_supported_best) { + n_supported_best = n_supported; + *node_backend_id = b; + SET_CAUSE(node, "3.best"); + } + } + } + } else { + // assigned node: upgrade to higher prio backend if possible + for (int b = 0; b < *node_backend_id; b++) { + if (sched->bufts[b] == sched->bufts[*node_backend_id] && ggml_backend_supports_op(sched->backends[b], node)) { + bool supported = true; + for (int j = 0; j < GGML_MAX_SRC; j++) { + struct ggml_tensor * src = node->src[j]; + if (src == NULL) { + continue; + } + if (!ggml_backend_sched_buffer_supported(sched, src, b)) { + supported = false; + break; + } + } + if (supported) { + *node_backend_id = b; + SET_CAUSE(node, "3.upg"); + break; + } + } + } + } + } - // pass 3: assign backends to remaining src from dst and view_src + // pass 4: assign backends to remaining src from dst and view_src for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; int * cur_backend_id = &tensor_backend_id(node); if (node->view_src != NULL && *cur_backend_id == -1) { *cur_backend_id = tensor_backend_id(node->view_src); - SET_CAUSE(node, "3.vsrc"); + SET_CAUSE(node, "4.vsrc"); } for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; @@ -1422,24 +1509,22 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg if (src->view_src != NULL) { // views are always on the same backend as the source *src_backend_id = tensor_backend_id(src->view_src); - SET_CAUSE(src, "3.vsrc"); + SET_CAUSE(src, "4.vsrc"); } else { *src_backend_id = *cur_backend_id; - SET_CAUSE(src, "3.cur"); + SET_CAUSE(src, "4.cur"); } } } } -#ifdef DEBUG_PASS3 - fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); -#endif - // pass 4: split graph, find tensors that need to be copied + // pass 5: split graph, find tensors that need to be copied { int i_split = 0; struct ggml_backend_sched_split * split = &sched->splits[0]; // find the backend of the first split, skipping view ops - for (int i = 0; i < graph->n_nodes; i++) { + int i = 0; + for (; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; if (!ggml_is_view_op(node->op)) { split->backend_id = tensor_backend_id(node); @@ -1448,9 +1533,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg } split->i_start = 0; split->n_inputs = 0; - memset(split->inputs, 0, sizeof(split->inputs)); //HACK int cur_backend_id = split->backend_id; - for (int i = 0; i < graph->n_nodes; i++) { + for (; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; if (ggml_is_view_op(node->op)) { @@ -1459,7 +1543,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg const int node_backend_id = tensor_backend_id(node); - GGML_ASSERT(node_backend_id != -1); // all nodes should be assigned by now + assert(node_backend_id != -1); // all nodes should be assigned by now // check if we should start a new split based on the sources of the current node bool need_new_split = false; @@ -1473,16 +1557,18 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg // by starting a new split, the memory of the previously offloaded weights can be reused if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = tensor_backend_id(src); - if (src_backend_id != -1 && src_backend_id != cur_backend_id) { + if (src_backend_id != cur_backend_id) { need_new_split = true; break; } } // check if the split has too many inputs + // FIXME: count the number of inputs instead of only checking when full if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) { const size_t id = hash_id(src); - int src_backend_id = sched->tensor_backend_id[id]; - if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) { + int src_backend_id = sched->hv_tensor_backend_ids[id]; + bool supported = ggml_backend_sched_buffer_supported(sched, src, cur_backend_id); + if (src_backend_id != cur_backend_id && tensor_id_copy(id, cur_backend_id, 0) == NULL && !supported) { //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name); need_new_split = true; break; @@ -1514,12 +1600,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg continue; } - const int src_backend_id = tensor_backend_id(src); + size_t src_id = hash_id(src); + const int src_backend_id = sched->hv_tensor_backend_ids[src_id]; assert(src_backend_id != -1); // all inputs should be assigned by now - if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { - size_t id = hash_id(src); - if (sched->tensor_copies[id][src_backend_id][0] == NULL) { + if (src->flags & GGML_TENSOR_FLAG_INPUT && sched->n_copies > 1) { + if (tensor_id_copy(src_id, src_backend_id, 0) == NULL) { ggml_backend_t backend = sched->backends[src_backend_id]; for (int c = 0; c < sched->n_copies; c++) { struct ggml_tensor * tensor_copy; @@ -1533,7 +1619,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg ggml_set_input(tensor_copy); ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor } - sched->tensor_copies[id][src_backend_id][c] = tensor_copy; + tensor_id_copy(src_id, src_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } int n_graph_inputs = sched->n_graph_inputs++; @@ -1542,10 +1628,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg } } - if (src_backend_id != node_backend_id) { + if (src_backend_id != cur_backend_id && !ggml_backend_sched_buffer_supported(sched, src, cur_backend_id)) { // create a copy of the input in the split's backend - const size_t id = hash_id(src); - if (sched->tensor_copies[id][cur_backend_id][0] == NULL) { + if (tensor_id_copy(src_id, cur_backend_id, 0) == NULL) { ggml_backend_t backend = sched->backends[cur_backend_id]; for (int c = 0; c < sched->n_copies; c++) { struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src); @@ -1554,27 +1639,49 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg ggml_set_input(tensor_copy); ggml_set_output(tensor_copy); // prevent ggml-alloc from overwriting the tensor } - sched->tensor_copies[id][cur_backend_id][c] = tensor_copy; + tensor_id_copy(src_id, cur_backend_id, c) = tensor_copy; SET_CAUSE(tensor_copy, "4.cpy"); } int n_inputs = split->n_inputs++; GGML_ASSERT(n_inputs < GGML_SCHED_MAX_SPLIT_INPUTS); split->inputs[n_inputs] = src; } - node->src[j] = sched->tensor_copies[id][cur_backend_id][sched->cur_copy]; + node->src[j] = tensor_id_copy(src_id, cur_backend_id, sched->cur_copy); } } } split->i_end = graph->n_nodes; sched->n_splits = i_split + 1; } -#ifdef DEBUG_PASS4 - fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); ggml_backend_sched_print_assignments(sched, graph); -#endif - // create copies of the graph for each split - // TODO: avoid this copy - struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2, false); + if (sched->debug) { + ggml_backend_sched_print_assignments(sched, graph); + } + + // swap node_backend_ids and leaf _backend_ids with prevs + { + int * tmp = sched->node_backend_ids; + sched->node_backend_ids = sched->prev_node_backend_ids; + sched->prev_node_backend_ids = tmp; + + tmp = sched->leaf_backend_ids; + sched->leaf_backend_ids = sched->prev_leaf_backend_ids; + sched->prev_leaf_backend_ids = tmp; + } + + int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2; + if (sched->graph.size < graph_size) { + sched->graph.size = graph_size; + sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); + sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *)); + GGML_ASSERT(sched->graph.nodes != NULL); + GGML_ASSERT(sched->graph.leafs != NULL); + } + sched->graph.n_nodes = 0; + sched->graph.n_leafs = 0; + + struct ggml_cgraph * graph_copy = &sched->graph; + for (int i = 0; i < sched->n_splits; i++) { struct ggml_backend_sched_split * split = &sched->splits[i]; split->graph = ggml_graph_view(graph, split->i_start, split->i_end); @@ -1585,12 +1692,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg struct ggml_tensor * input = split->inputs[j]; const size_t input_id = hash_id(input); - struct ggml_tensor * input_cpy = sched->tensor_copies[input_id][split->backend_id][sched->cur_copy]; + struct ggml_tensor * input_cpy = tensor_id_copy(input_id, split->backend_id, sched->cur_copy); // add a dependency to the input source so that it is not freed before the copy is done struct ggml_tensor * input_dep = ggml_view_tensor(sched->ctx, input); input_dep->src[0] = input; - sched->node_backend_ids[graph_copy->n_nodes] = sched->tensor_backend_id[input_id]; + sched->node_backend_ids[graph_copy->n_nodes] = sched->hv_tensor_backend_ids[input_id]; graph_copy->nodes[graph_copy->n_nodes++] = input_dep; // add a dependency to the input copy so that it is allocated at the start of the split @@ -1612,7 +1719,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg size_t id = hash_id(input); int backend_id = tensor_backend_id(input); for (int c = 0; c < sched->n_copies; c++) { - struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c]; + struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c); sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; } @@ -1625,7 +1732,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg struct ggml_tensor * input = split->inputs[j]; size_t id = hash_id(input); for (int c = 0; c < sched->n_copies; c++) { - struct ggml_tensor * input_cpy = sched->tensor_copies[id][backend_id][c]; + struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c); sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id; graph_copy->leafs[graph_copy->n_leafs++] = input_cpy; } @@ -1639,20 +1746,36 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } - - sched->graph = graph_copy; } static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { + bool backend_ids_changed = false; + for (int i = 0; i < sched->graph.n_nodes; i++) { + if (sched->node_backend_ids[i] != sched->prev_node_backend_ids[i] && + sched->bufts[sched->node_backend_ids[i]] != sched->bufts[sched->prev_node_backend_ids[i]]) { + backend_ids_changed = true; + break; + } + } + if (!backend_ids_changed) { + for (int i = 0; i < sched->graph.n_leafs; i++) { + if (sched->leaf_backend_ids[i] != sched->prev_leaf_backend_ids[i] && + sched->bufts[sched->leaf_backend_ids[i]] != sched->bufts[sched->prev_leaf_backend_ids[i]]) { + backend_ids_changed = true; + break; + } + } + } + // allocate graph - if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { // the re-allocation may cause the split inputs to be moved to a different address ggml_backend_sched_synchronize(sched); #ifndef NDEBUG - fprintf(stderr, "%s: failed to allocate graph, reserving\n", __func__); + fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif - ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); - if (!ggml_gallocr_alloc_graph(sched->galloc, sched->graph)) { + ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); + if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { fprintf(stderr, "%s: failed to allocate graph\n", __func__); return false; } @@ -1673,7 +1796,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s for (int j = 0; j < split->n_inputs; j++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); struct ggml_tensor * input = split->inputs[j]; - struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy]; + struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy); if (input->flags & GGML_TENSOR_FLAG_INPUT) { // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done @@ -1758,18 +1881,24 @@ ggml_backend_sched_t ggml_backend_sched_new( struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched)); + sched->debug = getenv("GGML_SCHED_DEBUG") != NULL; + sched->n_backends = n_backends; + sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; + // initialize hash table - sched->hash_set = ggml_hash_set_new(graph_size); - sched->tensor_backend_id = calloc(sched->hash_set.size, sizeof(sched->tensor_backend_id[0])); - sched->tensor_copies = calloc(sched->hash_set.size, sizeof(sched->tensor_copies[0])); + // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead) + sched->hash_set = ggml_hash_set_new(graph_size); + sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; - sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); - sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); + sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); + sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); - sched->n_backends = n_backends; - - sched->n_copies = parallel ? GGML_SCHED_MAX_COPIES : 1; + sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); + sched->context_buffer = malloc(sched->context_buffer_size); const int initial_splits_capacity = 16; sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0])); @@ -1778,7 +1907,7 @@ ggml_backend_sched_t ggml_backend_sched_new( for (int b = 0; b < n_backends; b++) { sched->backends[b] = backends[b]; sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]); - GGML_ASSERT(ggml_backend_buft_supports_backend(sched->bufts[b], backends[b])); + GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b])); if (sched->n_copies > 1) { for (int c = 0; c < sched->n_copies; c++) { sched->events[b][c] = ggml_backend_event_new(backends[b]); @@ -1804,35 +1933,37 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { } ggml_gallocr_free(sched->galloc); ggml_free(sched->ctx); + ggml_hash_set_free(&sched->hash_set); free(sched->splits); - free(sched->hash_set.keys); - free(sched->tensor_backend_id); - free(sched->tensor_copies); + free(sched->hv_tensor_backend_ids); + free(sched->hv_tensor_copies); free(sched->node_backend_ids); free(sched->leaf_backend_ids); + free(sched->prev_node_backend_ids); + free(sched->prev_leaf_backend_ids); + free(sched->context_buffer); + free(sched->graph.nodes); + free(sched->graph.leafs); free(sched); } void ggml_backend_sched_reset(ggml_backend_sched_t sched) { // reset state for the next run if (!sched->is_reset) { - size_t hash_size = sched->hash_set.size; - memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); // NOLINT - memset(sched->tensor_backend_id, -1, sizeof(sched->tensor_backend_id[0]) * hash_size); - memset(sched->tensor_copies, 0, sizeof(sched->tensor_copies[0]) * hash_size); - + ggml_hash_set_reset(&sched->hash_set); + memset(sched->hv_tensor_backend_ids, -1, sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + memset(sched->hv_tensor_copies, 0, sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); sched->is_reset = true; } sched->is_alloc = false; } bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { - GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes); + GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); ggml_backend_sched_split_graph(sched, measure_graph); - // TODO: extract this to a separate function - if (!ggml_gallocr_reserve_n(sched->galloc, sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } @@ -1843,10 +1974,11 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * } bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { - GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes); + GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); ggml_backend_sched_split_graph(sched, graph); + if (!ggml_backend_sched_alloc_splits(sched)) { return false; } @@ -1895,6 +2027,15 @@ int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { return sched->n_copies; } +int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) { + return sched->n_backends; +} + +ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) { + GGML_ASSERT(i >= 0 && i < sched->n_backends); + return sched->backends[i]; +} + size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); @@ -1906,6 +2047,8 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg int backend_index = ggml_backend_sched_backend_id(sched, backend); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); tensor_backend_id(node) = backend_index; + SET_CAUSE(node, "usr"); + sched->is_reset = false; } ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { @@ -1948,9 +2091,9 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, GGML_ASSERT(src != NULL); GGML_ASSERT(src->data && "graph must be allocated"); - size_t id = ggml_hash_insert(hash_set, src); - if (id == GGML_HASHTABLE_ALREADY_EXISTS) { - return node_copies[ggml_hash_find(hash_set, src)]; + size_t id = ggml_hash_insert(&hash_set, src); + if (id == GGML_HASHSET_ALREADY_EXISTS) { + return node_copies[ggml_hash_find(&hash_set, src)]; } struct ggml_tensor * dst = ggml_dup_tensor_layout(src->data && !src->view_src ? ctx_allocated : ctx_unallocated, src); @@ -1975,7 +2118,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, return dst; } -static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) { +static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_tensor ** node_copies, bool * node_init, struct ggml_tensor * src) { size_t id = ggml_hash_find(hash_set, src); if (node_init[id]) { return; @@ -2002,10 +2145,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te } struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { - struct ggml_hash_set hash_set = { - /* .size = */ graph->visited_hash_table.size, - /* .keys = */ calloc(graph->visited_hash_table.size, sizeof(hash_set.keys[0])) // NOLINT - }; + struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT bool * node_init = calloc(hash_set.size, sizeof(node_init[0])); @@ -2020,7 +2160,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s if (ctx_allocated == NULL || ctx_unallocated == NULL) { fprintf(stderr, "failed to allocate context for graph copy\n"); - free(hash_set.keys); + ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); ggml_free(ctx_allocated); @@ -2043,7 +2183,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend); if (buffer == NULL) { fprintf(stderr, "failed to allocate buffer for graph copy\n"); - free(hash_set.keys); + ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); ggml_free(ctx_allocated); @@ -2061,19 +2201,19 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s // copy data and init views for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; - graph_copy_init_tensor(hash_set, node_copies, node_init, node); + graph_copy_init_tensor(&hash_set, node_copies, node_init, node); } // build graph copy struct ggml_cgraph * graph_copy = ggml_new_graph_custom(ctx_allocated, graph->size, false); for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; - struct ggml_tensor * node_copy = node_copies[ggml_hash_find(hash_set, node)]; + struct ggml_tensor * node_copy = node_copies[ggml_hash_find(&hash_set, node)]; graph_copy->nodes[i] = node_copy; } graph_copy->n_nodes = graph->n_nodes; - free(hash_set.keys); + ggml_hash_set_free(&hash_set); free(node_copies); free(node_init); diff --git a/llama/ggml-backend.h b/llama/ggml-backend.h index 602cd030..7950571d 100644 --- a/llama/ggml-backend.h +++ b/llama/ggml-backend.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -49,28 +49,29 @@ extern "C" { GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); GGML_API GGML_CALL size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); - GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend); GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); // buffer enum ggml_backend_buffer_usage { GGML_BACKEND_BUFFER_USAGE_ANY = 0, GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1, + GGML_BACKEND_BUFFER_USAGE_COMPUTE = 2, }; - GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); - GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); - GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); - GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); - GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer); - GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); + GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer); + GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer); + GGML_API GGML_CALL void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); + GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + GGML_API enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage (ggml_backend_buffer_t buffer); + GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer); + GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // // Backend @@ -100,6 +101,7 @@ extern "C" { GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph); GGML_API enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph); GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op); + GGML_API bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft); GGML_API bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op); // tensor copy between different backends @@ -116,7 +118,7 @@ extern "C" { GGML_API void ggml_backend_event_free (ggml_backend_event_t event); GGML_API void ggml_backend_event_record (ggml_backend_event_t event); GGML_API void ggml_backend_event_synchronize(ggml_backend_event_t event); - GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); // wait async on event + GGML_API void ggml_backend_event_wait (ggml_backend_t backend, ggml_backend_event_t event); // // CPU backend @@ -145,7 +147,7 @@ extern "C" { GGML_API size_t ggml_backend_reg_get_count(void); GGML_API size_t ggml_backend_reg_find_by_name(const char * name); - GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is name[:params] + GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional) GGML_API const char * ggml_backend_reg_get_name(size_t i); GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific GGML_API ggml_backend_buffer_type_t ggml_backend_reg_get_default_buffer_type(size_t i); @@ -208,6 +210,9 @@ extern "C" { // Initialize backend buffers from a measure graph GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); + GGML_API int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched); + GGML_API ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i); + // Get the number of splits of the last graph GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched); GGML_API int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched); diff --git a/llama/ggml-common.h b/llama/ggml-common.h index 068516f7..8ff58bfa 100644 --- a/llama/ggml-common.h +++ b/llama/ggml-common.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -45,7 +45,11 @@ typedef half2 ggml_half2; #define GGML_COMMON_DECL #elif defined(GGML_COMMON_DECL_CUDA) +#if defined(GGML_COMMON_DECL_MUSA) +#include +#else #include +#endif #include typedef half ggml_half; @@ -132,19 +136,19 @@ typedef sycl::half2 ggml_half2; #define QR6_K 2 #define QI2_XXS (QK_K / (4*QR2_XXS)) -#define QR2_XXS 8 +#define QR2_XXS 4 #define QI2_XS (QK_K / (4*QR2_XS)) -#define QR2_XS 8 +#define QR2_XS 4 #define QI2_S (QK_K / (4*QR2_S)) -#define QR2_S 8 +#define QR2_S 4 #define QI3_XXS (QK_K / (4*QR3_XXS)) -#define QR3_XXS 8 +#define QR3_XXS 4 #define QI3_XS (QK_K / (4*QR3_XS)) -#define QR3_XS 8 +#define QR3_XS 4 #define QI1_S (QK_K / (4*QR1_S)) #define QR1_S 8 @@ -156,10 +160,10 @@ typedef sycl::half2 ggml_half2; #define QR4_NL 2 #define QI4_XS (QK_K / (4*QR4_XS)) -#define QR4_XS 8 +#define QR4_XS 2 #define QI3_S (QK_K / (4*QR3_S)) -#define QR3_S 8 +#define QR3_S 4 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP @@ -225,6 +229,30 @@ typedef struct { } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding"); +typedef struct { + ggml_half d[4]; // deltas for 4 q4_0 blocks + uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks +} block_q4_0x4; +static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding"); + +typedef struct { + ggml_half d[8]; // deltas for 8 q4_0 blocks + uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks +} block_q4_0x8; +static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding"); + +typedef struct { + ggml_half d[4]; // deltas for 4 q8_0 blocks + int8_t qs[QK8_0 * 4]; // quants for 4 q8_0 blocks +} block_q8_0x4; +static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding"); + +typedef struct { + ggml_half d[8]; // deltas for 8 q8_0 blocks + int8_t qs[QK8_0 * 8]; // quants for 8 q8_0 blocks +} block_q8_0x8; +static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); + // // Super-block quantization structures // @@ -417,7 +445,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_ #define GGML_TABLE_END() }; #define GGML_COMMON_IMPL -#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) +#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA) #include #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { diff --git a/llama/ggml-cuda.cu b/llama/ggml-cuda.cu index 29f6c756..a3341229 100644 --- a/llama/ggml-cuda.cu +++ b/llama/ggml-cuda.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -55,6 +55,7 @@ #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" +#include "ggml-cuda/conv-transpose-1d.cuh" #include #include @@ -123,7 +124,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in GGML_CUDA_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); GGML_CUDA_LOG_ERROR(" %s\n", stmt); // abort with GGML_ASSERT to get a stack trace - GGML_ASSERT(!"CUDA error"); + GGML_ABORT("CUDA error"); } // this is faster on Windows @@ -178,21 +179,21 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; -#if defined(GGML_CUDA_FORCE_MMQ) - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); +#ifdef GGML_CUDA_FORCE_MMQ + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); -#endif -#if defined(CUDA_USE_TENSOR_CORES) - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); +#endif // GGML_CUDA_FORCE_MMQ +#ifdef GGML_CUDA_FORCE_CUBLAS + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); #else - GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__); -#endif + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); +#endif // GGML_CUDA_FORCE_CUBLAS GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -204,7 +205,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -214,13 +215,15 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + info.devices[id].smpbo = prop.sharedMemPerBlock; info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; #else + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - info.devices[id].smpb = prop.sharedMemPerBlock; - info.devices[id].nsm = prop.multiProcessorCount; } for (int id = 0; id < info.device_count; ++id) { @@ -338,7 +341,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -432,14 +435,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); } }; -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device)); } -#endif +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) return std::unique_ptr(new ggml_cuda_pool_leg(device)); } @@ -491,12 +494,12 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t return; } - if (ggml_is_quantized(tensor->type)) { + if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) { // initialize padding to 0 to avoid possible NaN values size_t original_size = ggml_nbytes(tensor); size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); - if (padded_size > original_size && tensor->view_src == nullptr) { + if (padded_size > original_size) { ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size)); } @@ -573,6 +576,10 @@ GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_bu return ctx->name.c_str(); } +static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_buffer_type_name; +} + GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; @@ -615,24 +622,12 @@ GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backen GGML_UNUSED(buft); } -GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - if (!ggml_backend_is_cuda(backend)) { - return false; - } - - ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - - return buft_ctx->device == cuda_ctx->device; -} - static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { /* .get_name = */ ggml_backend_cuda_buffer_type_name, /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, - /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend, /* .is_host = */ NULL, }; @@ -671,7 +666,7 @@ static int64_t get_row_rounding(const std::array & } const int cc = ggml_cuda_info().devices[id].cc; - row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc))); + row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc)); } return row_rounding; } @@ -893,6 +888,10 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_back GGML_UNUSED(buft); } +static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name; +} + GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point // instead, we allocate them for each tensor separately in init_tensor @@ -936,12 +935,6 @@ GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_ return total_size; } -GGML_CALL static bool ggml_backend_cuda_split_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - return ggml_backend_is_cuda(backend); - - GGML_UNUSED(buft); -} - GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return false; @@ -954,7 +947,6 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size, - /* .supports_backend = */ ggml_backend_cuda_split_buffer_type_supports_backend, /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; @@ -1054,7 +1046,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, - /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, }, /* .context = */ nullptr, @@ -1377,10 +1368,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { GGML_UNUSED(main_device); } +static cudaError_t ggml_cuda_Memcpy2DPeerAsync( + void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { + +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) + // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices + cudaMemcpy3DPeerParms p = {}; + p.dstDevice = dstDevice; + p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height); + p.srcDevice = srcDevice; + p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height); + p.extent = make_cudaExtent(width, height, 1); + return cudaMemcpy3DPeerAsync(&p, stream); +#else + // HIP does not support cudaMemcpy3DPeerAsync or vmm pools + GGML_UNUSED(dstDevice); + GGML_UNUSED(srcDevice); + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) +} + static void ggml_cuda_op_mul_mat( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, - const bool convert_src1_to_q8_1) { + quantize_cuda_t quantize_src1) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -1437,7 +1448,9 @@ static void ggml_cuda_op_mul_mat( } struct dev_data { - ggml_cuda_pool_alloc src0_dd_alloc; + int cc; + + ggml_cuda_pool_alloc src0_dd_alloc; ggml_cuda_pool_alloc src1_ddf_alloc; ggml_cuda_pool_alloc src1_ddq_alloc; ggml_cuda_pool_alloc dst_dd_alloc; @@ -1456,6 +1469,8 @@ static void ggml_cuda_op_mul_mat( int used_devices = 0; for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + dev[id].cc = ggml_cuda_info().devices[id].cc; + // by default, use all rows dev[id].row_low = 0; dev[id].row_high = ne01; @@ -1500,17 +1515,28 @@ static void ggml_cuda_op_mul_mat( dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0)); } + // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared: + if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { + const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); + const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); + } + if (src1_on_device && src1_is_contiguous) { dev[id].src1_ddf = (float *) src1->data; } else { dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); } - if (convert_src1_to_q8_1) { - dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq); + } + dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } } @@ -1556,7 +1582,12 @@ static void ggml_cuda_op_mul_mat( const int64_t i03 = i0 / ne12; const int64_t i02 = i0 % ne12; - const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq); + } else { + src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs; + } // for split tensors the data begins at i0 == i0_offset_low char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; @@ -1573,10 +1604,17 @@ static void ggml_cuda_op_mul_mat( // copy src0, src1 to device if necessary if (src1_is_contiguous) { if (id != ctx.device) { - if (convert_src1_to_q8_1) { + if (quantize_src1) { char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; - CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device, - src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + const size_t pitch = ne11*sizeof(block_q8_1_mmq); + const size_t width = src1_ncols*sizeof(block_q8_1_mmq); + const size_t height = src1_padded_col_size/(4*QK8_1); + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream)); + } else { + CUDA_CHECK(cudaMemcpyPeerAsync( + src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + } } else { float * src1_ddf_i_source = (float *) src1->data; src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; @@ -1588,11 +1626,11 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_cpy_tensor_2d( src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } - if (convert_src1_to_q8_1 && !src1_is_contiguous) { - quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + if (quantize_src1 && !src1_is_contiguous) { + quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1617,22 +1655,8 @@ static void ggml_cuda_op_mul_mat( float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); dhf_dst_i += src1_col_0*ne0 + dev[id].row_low; -#if !defined(GGML_USE_HIPBLAS) - // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices - cudaMemcpy3DPeerParms p = {}; - p.dstDevice = ctx.device; - p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols); - p.srcDevice = id; - p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols); - p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1); - CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream)); -#else - // HIP does not support cudaMemcpy3DPeerAsync or vmm pools - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), - dst_dd_i, row_diff*sizeof(float), - row_diff*sizeof(float), src1_ncols, - cudaMemcpyDeviceToDevice, stream)); -#endif + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync( + dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream)); } else { float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); @@ -1834,6 +1858,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else +#ifdef GGML_USE_MUSA + GGML_ASSERT(false); +#else // !GGML_USE_MUSA if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -1876,6 +1903,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } +#endif // GGML_USE_MUSA #endif if (dst->op_params[0] == GGML_PREC_DEFAULT) { @@ -1887,9 +1915,23 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); - int64_t min_compute_capability = INT_MAX; + bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2 + && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_q = ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + // if mmvq is available it's a better choice than dmmv: +#ifndef GGML_CUDA_FORCE_DMMV + use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; +#endif // GGML_CUDA_FORCE_DMMV + + bool any_gpus_with_slow_fp16 = false; - bool any_pascal_with_slow_fp16 = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; @@ -1899,60 +1941,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - if (min_compute_capability > ggml_cuda_info().devices[id].cc) { - min_compute_capability = ggml_cuda_info().devices[id].cc; - } - if (ggml_cuda_info().devices[id].cc == 610) { - any_pascal_with_slow_fp16 = true; - } + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } } else { - min_compute_capability = ggml_cuda_info().devices[ctx.device].cc; - any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610; + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); } - // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1; - - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - - bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - - const bool fp16_performance_good = min_compute_capability >= CC_RDNA1; - -#ifdef CUDA_USE_TENSOR_CORES - use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3; -#endif // CUDA_USE_TENSOR_CORES - -#else - - // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0) - const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16; - - // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1 - use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A; - use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A; - -#ifdef CUDA_USE_TENSOR_CORES - // when tensor cores are available, use them for large batch size - // ref: https://github.com/ggerganov/llama.cpp/pull/3776 - use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE); -#endif // CUDA_USE_TENSOR_CORES - -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - - // if mmvq is available it's a better choice than dmmv: -#ifndef GGML_CUDA_FORCE_DMMV - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; -#endif // GGML_CUDA_FORCE_DMMV - // debug helpers //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); @@ -1961,23 +1959,24 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // KQ single-batch + if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // FP32 precision KQ single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // KQV single-batch + } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); } else { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } } @@ -2281,6 +2280,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SQR: ggml_cuda_op_sqr(ctx, dst); break; + case GGML_OP_SQRT: + ggml_cuda_op_sqrt(ctx, dst); + break; case GGML_OP_CLAMP: ggml_cuda_op_clamp(ctx, dst); break; @@ -2302,6 +2304,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_IM2COL: ggml_cuda_op_im2col(ctx, dst); break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_cuda_op_conv_transpose_1d(ctx,dst); + break; case GGML_OP_POOL_2D: ggml_cuda_op_pool2d(ctx, dst); break; @@ -2744,7 +2749,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: - return true; + return ggml_is_contiguous(op->src[0]); default: return false; } @@ -2752,27 +2757,40 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: { - struct ggml_tensor * a; - struct ggml_tensor * b; + struct ggml_tensor * a = op->src[0]; if (op->op == GGML_OP_MUL_MAT) { - a = op->src[0]; - b = op->src[1]; - } else { - a = op->src[2]; - b = op->src[1]; - } - if (a->ne[3] != b->ne[3]) { - return false; - } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ1_M || a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { + struct ggml_tensor * b = op->src[1]; + if (a->ne[3] != b->ne[3]) { return false; } } - return true; + switch (a->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + return true; + default: + return false; + } } break; case GGML_OP_GET_ROWS: { @@ -2832,6 +2850,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons ggml_type src0_type = op->src[0]->type; return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + return false; + } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -2844,6 +2871,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_RMS_NORM: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SQRT: case GGML_OP_CLAMP: case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: @@ -2883,6 +2911,20 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons GGML_UNUSED(backend); } +GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + if (ggml_backend_buft_is_cuda_split(buft)) { + return true; + } + + if (ggml_backend_buft_is_cuda(buft)) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; + return buft_ctx->device == cuda_ctx->device; + } + + return false; +} + GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; @@ -2937,7 +2979,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event)); #endif - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -2955,9 +2997,11 @@ static ggml_backend_i ggml_backend_cuda_interface = { /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .supports_op = */ ggml_backend_cuda_supports_op, + /* .supports_buft = */ ggml_backend_cuda_supports_buft, /* .offload_op = */ ggml_backend_cuda_offload_op, /* .event_new = */ ggml_backend_cuda_event_new, /* .event_free = */ ggml_backend_cuda_event_free, @@ -3017,7 +3061,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size return false; } -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error diff --git a/llama/ggml-cuda.h b/llama/ggml-cuda.h index 9f2b1c88..fce52bf9 100644 --- a/llama/ggml-cuda.h +++ b/llama/ggml-cuda.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -32,6 +32,9 @@ #ifdef GGML_USE_HIPBLAS #define GGML_CUDA_NAME "ROCm" #define GGML_CUBLAS_NAME "hipBLAS" +#elif defined(GGML_USE_MUSA) +#define GGML_CUDA_NAME "MUSA" +#define GGML_CUBLAS_NAME "muBLAS" #else #define GGML_CUDA_NAME "CUDA" #define GGML_CUBLAS_NAME "cuBLAS" diff --git a/llama/ggml-cuda/acc.cu b/llama/ggml-cuda/acc.cu index e6e49958..0f55c157 100644 --- a/llama/ggml-cuda/acc.cu +++ b/llama/ggml-cuda/acc.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/acc.cuh b/llama/ggml-cuda/acc.cuh index acddd4b8..519c95c8 100644 --- a/llama/ggml-cuda/acc.cuh +++ b/llama/ggml-cuda/acc.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/alibi.cu b/llama/ggml-cuda/alibi.cu index 0ff2f9f3..35d276b5 100644 --- a/llama/ggml-cuda/alibi.cu +++ b/llama/ggml-cuda/alibi.cu @@ -1,83 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file - * - * MIT License - * - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file - * - * MIT License - * - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file - * - * MIT License - * - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/alibi.cuh b/llama/ggml-cuda/alibi.cuh index 6087d0b0..0d6a3440 100644 --- a/llama/ggml-cuda/alibi.cuh +++ b/llama/ggml-cuda/alibi.cuh @@ -1,57 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file - * - * MIT License - * - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file - * - * MIT License - * - * Copyright (c) 2023-2024 The ggml authors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -/** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/arange.cu b/llama/ggml-cuda/arange.cu index 598620d0..514c146e 100644 --- a/llama/ggml-cuda/arange.cu +++ b/llama/ggml-cuda/arange.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/arange.cuh b/llama/ggml-cuda/arange.cuh index db1650db..f1d8acc2 100644 --- a/llama/ggml-cuda/arange.cuh +++ b/llama/ggml-cuda/arange.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/argsort.cu b/llama/ggml-cuda/argsort.cu index d7972d58..1987e87f 100644 --- a/llama/ggml-cuda/argsort.cu +++ b/llama/ggml-cuda/argsort.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -99,6 +99,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co const dim3 block_nums(1, nrows, 1); const size_t shared_mem = ncols_pad * sizeof(int); + // FIXME: this limit could be raised by ~2-4x on Ampere or newer GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { @@ -106,7 +107,7 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co } else if (order == GGML_SORT_ORDER_DESC) { k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } diff --git a/llama/ggml-cuda/argsort.cuh b/llama/ggml-cuda/argsort.cuh index 7ae91c17..9189815c 100644 --- a/llama/ggml-cuda/argsort.cuh +++ b/llama/ggml-cuda/argsort.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/binbcast.cu b/llama/ggml-cuda/binbcast.cu index 3e1c6585..df396eb2 100644 --- a/llama/ggml-cuda/binbcast.cu +++ b/llama/ggml-cuda/binbcast.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -285,7 +285,7 @@ static void ggml_cuda_op_bin_bcast( } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } diff --git a/llama/ggml-cuda/binbcast.cuh b/llama/ggml-cuda/binbcast.cuh index b5cbd086..e6a48196 100644 --- a/llama/ggml-cuda/binbcast.cuh +++ b/llama/ggml-cuda/binbcast.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/clamp.cu b/llama/ggml-cuda/clamp.cu index 97b498be..844cb913 100644 --- a/llama/ggml-cuda/clamp.cu +++ b/llama/ggml-cuda/clamp.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/clamp.cuh b/llama/ggml-cuda/clamp.cuh index 9d94460c..2d25cb00 100644 --- a/llama/ggml-cuda/clamp.cuh +++ b/llama/ggml-cuda/clamp.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/common.cuh b/llama/ggml-cuda/common.cuh index e465649f..dbd07204 100644 --- a/llama/ggml-cuda/common.cuh +++ b/llama/ggml-cuda/common.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -29,6 +29,7 @@ #include "ggml.h" #include "ggml-cuda.h" +#include #include #if defined(GGML_USE_HIPBLAS) @@ -37,6 +38,10 @@ #else #define GGML_COMMON_DECL_CUDA #define GGML_COMMON_IMPL_CUDA +#if defined(GGML_USE_MUSA) +#define GGML_COMMON_DECL_MUSA +#define GGML_COMMON_IMPL_MUSA +#endif #endif #include "ggml-common.h" @@ -129,7 +134,7 @@ #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess -#define __trap abort +#define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED @@ -139,6 +144,150 @@ #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED +#elif defined(GGML_USE_MUSA) +#include +#include +#include +#include +// XXX: Keep the following order the same as hipBLAS +// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F +// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +// #define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_32F MUSA_R_32F +// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +// #define cublasComputeType_t mublasComputeType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6 +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// XXX: Other CUDA => MUSA mapping +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// XXX: USE_CUDA_GRAPH +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamEndCapture musaStreamEndCapture + +// XXX: cuBLAS => muBLAS mapping +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t + +// XXX: Clang builtins mapping +#define __vsub4 __vsub4_musa +#define __vcmpeq4 __vcmpeq4_musa +#define __vcmpne4 __vcmpne4_musa #else #include #include @@ -165,29 +314,13 @@ #define CC_PASCAL 600 #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define CC_VOLTA 700 +#define CC_TURING 750 #define CC_AMPERE 800 #define CC_OFFSET_AMD 1000000 #define CC_RDNA1 (CC_OFFSET_AMD + 1010) #define CC_RDNA2 (CC_OFFSET_AMD + 1030) #define CC_RDNA3 (CC_OFFSET_AMD + 1100) -// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication -// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant -// for large computational tasks. the drawback is that this requires some extra amount of VRAM: -// - 7B quantum model: +100-200 MB -// - 13B quantum model: +200-400 MB -// -//#define GGML_CUDA_FORCE_MMQ - -// TODO: improve this to be correct for more hardware -// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores -#if !defined(GGML_CUDA_FORCE_MMQ) -#define CUDA_USE_TENSOR_CORES -#endif - -#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available - #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses #if defined(_MSC_VER) @@ -209,9 +342,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) -#if CUDART_VERSION >= 12000 +#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) static const char * cublas_get_error_str(const cublasStatus_t err) { +#ifndef GGML_USE_MUSA return cublasGetStatusString(err); +#else + return mublasStatus_to_string(err); +#endif // GGML_USE_MUSA } #else static const char * cublas_get_error_str(const cublasStatus_t err) { @@ -241,7 +378,7 @@ static const char * cu_get_error_str(CUresult err) { #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #endif -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) #define GGML_CUDA_ASSUME(x) __builtin_assume(x) #else #define GGML_CUDA_ASSUME(x) @@ -255,6 +392,42 @@ typedef float dfloat; // dequantize float typedef float2 dfloat2; #endif //GGML_CUDA_F16 +#if defined(GGML_USE_MUSA) +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); + +static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) { + return __vsubss4(a, b); +} + +static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0xff : 0x00; + } + return c; +} + +static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0x00 : 0xff; + } + return c; +} +#endif // defined(GGML_USE_MUSA) + #if defined(GGML_USE_HIPBLAS) #define __CUDA_ARCH__ 1300 @@ -268,6 +441,10 @@ typedef float2 dfloat2; #define RDNA2 #endif +#if defined(__gfx1010__) || defined(__gfx1012__) +#define RDNA1 +#endif + #ifndef __has_builtin #define __has_builtin(x) 0 #endif @@ -310,30 +487,15 @@ static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigne return c; } -static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { -#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) - c = __builtin_amdgcn_sdot4(a, b, c, false); -#elif defined(RDNA3) - c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); -#elif defined(__gfx1010__) || defined(__gfx900__) - int tmp1; - int tmp2; - asm("\n \ - v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ - v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ - v_add3_u32 %0, %1, %2, %0 \n \ - v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ - v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ - v_add3_u32 %0, %1, %2, %0 \n \ - " - : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) - : "v"(a), "v"(b) - ); -#else - const int8x4_t va = reinterpret_cast(a); - const int8x4_t vb = reinterpret_cast(b); - c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; -#endif +static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast(a); + const uint8x4_t& vb = reinterpret_cast(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0x00 : 0xff; + } return c; } @@ -352,18 +514,34 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int #endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000 #endif // defined(GGML_USE_HIPBLAS) -#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL +#define FP16_AVAILABLE +#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL -#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 +#define FAST_FP16_AVAILABLE +#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 -static bool fast_fp16_available(const int cc) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA +#define FP16_MMA_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING +#define INT8_MMA_AVAILABLE +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING + +static constexpr bool fast_fp16_available(const int cc) { return cc >= CC_PASCAL && cc != 610; } -static bool fp16_mma_available(const int cc) { +static constexpr bool fp16_mma_available(const int cc) { return cc < CC_OFFSET_AMD && cc >= CC_VOLTA; } +static constexpr bool int8_mma_available(const int cc) { + return cc < CC_OFFSET_AMD && cc >= CC_TURING; +} + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { @@ -384,7 +562,7 @@ static __device__ void no_device_code( #ifdef __CUDA_ARCH__ #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__)) #else -#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.") +#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ static __device__ __forceinline__ float warp_reduce_sum(float x) { @@ -405,7 +583,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { } static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #pragma unroll @@ -438,7 +616,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { } static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX return __float2half(fmaxf(__half2float(a), __half2float(b))); @@ -491,10 +669,50 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); return mask_low | mask_high; } -#endif // CUDART_VERSION < 12000 +#endif // CUDART_VERSION < CUDART_HMASK + +static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2) + c = __builtin_amdgcn_sdot4(a, b, c, false); +#elif defined(RDNA3) + c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); +#elif defined(__gfx1010__) || defined(__gfx900__) + int tmp1; + int tmp2; + asm("\n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + " + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) + : "v"(a), "v"(b) + ); +#else + const int8x4_t va = reinterpret_cast(a); + const int8x4_t vb = reinterpret_cast(b); + c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; +#endif + return c; + +#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + +#if __CUDA_ARCH__ >= MIN_CC_DP4A + return __dp4a(a, b, c); +#else // __CUDA_ARCH__ >= MIN_CC_DP4A + const int8_t * a8 = (const int8_t *) &a; + const int8_t * b8 = (const int8_t *) &b; + return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +} // TODO: move to ggml-common.h -static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; +static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); @@ -652,19 +870,6 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI3_S; }; -static int get_mmq_x_max_host(const int cc) { -#ifdef CUDA_USE_TENSOR_CORES - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; -#else - return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; -#endif // CUDA_USE_TENSOR_CORES -} - -// Round rows to this value for --split-mode row: -static int get_mmq_y_host(const int cc, const int mmq_x) { - return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64; -} - ////////////////////// struct ggml_cuda_device_info { @@ -674,6 +879,7 @@ struct ggml_cuda_device_info { int cc; // compute capability int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory size_t total_vram; diff --git a/llama/ggml-cuda/concat.cu b/llama/ggml-cuda/concat.cu index 351c379e..e77a1c44 100644 --- a/llama/ggml-cuda/concat.cu +++ b/llama/ggml-cuda/concat.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/concat.cuh b/llama/ggml-cuda/concat.cuh index b70e5669..f2010440 100644 --- a/llama/ggml-cuda/concat.cuh +++ b/llama/ggml-cuda/concat.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/conv-transpose-1d.cu b/llama/ggml-cuda/conv-transpose-1d.cu new file mode 100644 index 00000000..0117a6b7 --- /dev/null +++ b/llama/ggml-cuda/conv-transpose-1d.cu @@ -0,0 +1,113 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "conv-transpose-1d.cuh" + +static __global__ void conv_transpose_1d_kernel( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst) { + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + if (global_index >= output_size) { + return; + } + + int out_index = global_index / dst_ne0; + + float accumulator = 0; + + for (int c = 0; c < src0_ne2; c++) { + int idx = global_index % dst_ne0; + + int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); + int input_offset = src1_ne0 * c; + + for (int i = 0; i < src1_ne0; i++) { + if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) { + continue; + } + int weight_idx = idx - i*s0; + + float kernel_weight = src0[kernel_offset + weight_idx]; + float input_value = src1[input_offset+i]; + + accumulator += kernel_weight * input_value; + } + } + dst[global_index] = accumulator; +} + +static void conv_transpose_1d_f32_f32_cuda( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst, + cudaStream_t stream) { + + const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE; + conv_transpose_1d_kernel<<>>( + s0,p0,d0,output_size, + src0_ne0, src0_ne1, src0_ne2, src0_ne3, + src1_ne0, src1_ne1, src1_ne2, src1_ne3, + dst_ne0, dst_ne1, dst_ne2, dst_ne3, + src0,src1, dst); +} + +void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *)dst->op_params; + + const int s0 = opts[0]; + const int p0 = 0;//opts[3]; + const int d0 = 1;//opts[4]; + + const int64_t kernel_size = ggml_nelements(src0); + const int64_t input_size = ggml_nelements(src1); + const int64_t output_size = ggml_nelements(dst); + + conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); +} diff --git a/llama/ggml-cuda/conv-transpose-1d.cuh b/llama/ggml-cuda/conv-transpose-1d.cuh new file mode 100644 index 00000000..90ed15d0 --- /dev/null +++ b/llama/ggml-cuda/conv-transpose-1d.cuh @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "common.cuh" + +#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama/ggml-cuda/convert.cu b/llama/ggml-cuda/convert.cu index b5e56bf5..44a18e53 100644 --- a/llama/ggml-cuda/convert.cu +++ b/llama/ggml-cuda/convert.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/convert.cuh b/llama/ggml-cuda/convert.cuh index 3abba0da..a72f0206 100644 --- a/llama/ggml-cuda/convert.cuh +++ b/llama/ggml-cuda/convert.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/cpy.cu b/llama/ggml-cuda/cpy.cu index 408ad9c5..d5024659 100644 --- a/llama/ggml-cuda/cpy.cu +++ b/llama/ggml-cuda/cpy.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -477,7 +477,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -510,7 +510,6 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } - diff --git a/llama/ggml-cuda/cpy.cuh b/llama/ggml-cuda/cpy.cuh index af93b77e..9907eb3e 100644 --- a/llama/ggml-cuda/cpy.cuh +++ b/llama/ggml-cuda/cpy.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/dequantize.cuh b/llama/ggml-cuda/dequantize.cuh index b7e55460..4baf3f59 100644 --- a/llama/ggml-cuda/dequantize.cuh +++ b/llama/ggml-cuda/dequantize.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/diagmask.cu b/llama/ggml-cuda/diagmask.cu index 16dcfcd5..14dbb972 100644 --- a/llama/ggml-cuda/diagmask.cu +++ b/llama/ggml-cuda/diagmask.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/diagmask.cuh b/llama/ggml-cuda/diagmask.cuh index 62338819..1ec8e9ba 100644 --- a/llama/ggml-cuda/diagmask.cuh +++ b/llama/ggml-cuda/diagmask.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/dmmv.cu b/llama/ggml-cuda/dmmv.cu index 88779422..feb9bf80 100644 --- a/llama/ggml-cuda/dmmv.cu +++ b/llama/ggml-cuda/dmmv.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -688,7 +688,7 @@ void ggml_cuda_op_dequantize_mul_mat_vec( convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } diff --git a/llama/ggml-cuda/dmmv.cuh b/llama/ggml-cuda/dmmv.cuh index fbb28f4f..be2b3fa6 100644 --- a/llama/ggml-cuda/dmmv.cuh +++ b/llama/ggml-cuda/dmmv.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/fattn-common.cuh b/llama/ggml-cuda/fattn-common.cuh index bea6c4fd..cba14ae2 100644 --- a/llama/ggml-cuda/fattn-common.cuh +++ b/llama/ggml-cuda/fattn-common.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -80,12 +80,11 @@ typedef float (*vec_dot_KQ_f32_t)( template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; GGML_UNUSED(Q_v); - half sum = 0.0f; + T sum = 0.0f; #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { @@ -95,12 +94,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const int iqs4 = k_KQ % QI4_0; const int shift = k_KQ & (QI8_1/2); - const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; const int u = Q_q8[k_KQ_0/WARP_SIZE]; - const int sumi = __dp4a(v, u, 0); + const int sumi = ggml_cuda_dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -116,19 +115,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( } return sum; -#else - GGML_UNUSED(K_c); - GGML_UNUSED(Q_v); - GGML_UNUSED(Q_q8); - GGML_UNUSED(Q_ds_v); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; GGML_UNUSED(Q_v); @@ -143,12 +134,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( const int iqs4 = k_KQ % QI4_1; const int shift = k_KQ & (QI8_1/2); - const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; const int u = Q_q8[k_KQ_0/WARP_SIZE]; - const int sumi = __dp4a(v, u, 0); + const int sumi = ggml_cuda_dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -168,19 +159,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( } return sum; -#else - GGML_UNUSED(K_c); - GGML_UNUSED(Q_v); - GGML_UNUSED(Q_q8); - GGML_UNUSED(Q_ds_v); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; GGML_UNUSED(Q_v); @@ -196,8 +179,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const int iqs8 = k_KQ % QI8_1; const int shift = k_KQ & (QI8_1/2); - int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); + int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); v |= (vh << 4) & 0x00000010; // 0 -> 4 v |= (vh << 11) & 0x00001000; // 1 -> 12 v |= (vh << 18) & 0x00100000; // 2 -> 20 @@ -205,9 +188,9 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( const int u = Q_q8[k_KQ_0/WARP_SIZE]; - const int sumi = __dp4a(v, u, 0); + const int sumi = ggml_cuda_dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -223,19 +206,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( } return sum; -#else - GGML_UNUSED(K_c); - GGML_UNUSED(Q_v); - GGML_UNUSED(Q_q8); - GGML_UNUSED(Q_ds_v); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; GGML_UNUSED(Q_v); @@ -251,8 +226,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const int iqs8 = k_KQ % QI8_1; const int shift = k_KQ & (QI8_1/2); - int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; - const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); + int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); v |= (vh << 4) & 0x00000010; // 0 -> 4 v |= (vh << 11) & 0x00001000; // 1 -> 12 v |= (vh << 18) & 0x00100000; // 2 -> 20 @@ -260,9 +235,9 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( const int u = Q_q8[k_KQ_0/WARP_SIZE]; - const int sumi = __dp4a(v, u, 0); + const int sumi = ggml_cuda_dp4a(v, u, 0); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_ds = (const half2 *) Q_ds_v; @@ -282,19 +257,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( } return sum; -#else - GGML_UNUSED(K_c); - GGML_UNUSED(Q_v); - GGML_UNUSED(Q_q8); - GGML_UNUSED(Q_ds_v); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; GGML_UNUSED(Q_v); @@ -308,7 +275,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const int ib = k_KQ / QI8_0; const int iqs = k_KQ % QI8_0; - const int v = get_int_from_int8(K_q8_0[ib].qs, iqs); + const int v = get_int_b2(K_q8_0[ib].qs, iqs); T Q_d; if (std::is_same::value) { @@ -323,13 +290,6 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( } return sum; -#else - GGML_UNUSED(K_c); - GGML_UNUSED(Q_v); - GGML_UNUSED(Q_q8); - GGML_UNUSED(Q_ds_v); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template @@ -340,7 +300,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( GGML_UNUSED(Q_q8); GGML_UNUSED(Q_ds_v); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { const half2 * Q_h2 = (const half2 *) Q_v; @@ -433,7 +393,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F) - 8; -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return ((half) d)*((half) q); } @@ -454,7 +414,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return __low2half(dm)*((half) q) + __high2half(dm); } @@ -474,12 +434,12 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ const T d = x[ib].d; const int ql0 = x[ib].qs[iqs]; - const int qh0 = get_int_from_uint8(x[ib].qh, 0); + const int qh0 = get_int_b2(x[ib].qh, 0); const int ql = ((ql0 >> (4*shift)) & 0x0F); const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh) - 16; -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return ((half) d)*((half) q); } @@ -499,12 +459,12 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ const half2 dm = x[ib].dm; const int ql0 = x[ib].qs[iqs]; - const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0); + const int qh0 = get_int_b4(x[ib].qh, 0); const int ql = ((ql0 >> (4*shift)) & 0x0F); const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh); -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return __low2half(dm)*((half) q) + __high2half(dm); } @@ -523,7 +483,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ const T d = x[ib].d; const int q = x[ib].qs[iqs]; -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE if (std::is_same::value) { return ((half) d)*((half) q); } @@ -629,20 +589,20 @@ static void on_no_fattn_vec_case(const int D) { if (D == 64) { fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); fprintf(stderr, "By default only f16 KV cache is supported.\n"); - fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); - GGML_ASSERT(false); + fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); + GGML_ABORT("fatal error"); } else if (D == 128) { fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); fprintf(stderr, "Supported combinations:\n"); fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); - fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); - GGML_ASSERT(false); + fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); + GGML_ABORT("fatal error"); } else { fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); fprintf(stderr, "Only f16 is supported.\n"); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } diff --git a/llama/ggml-cuda/fattn-tile-f16.cu b/llama/ggml-cuda/fattn-tile-f16.cu index 153e248d..a4fc2127 100644 --- a/llama/ggml-cuda/fattn-tile-f16.cu +++ b/llama/ggml-cuda/fattn-tile-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -69,7 +69,7 @@ static __global__ void flash_attn_tile_ext_f16( const int ne1, const int ne2, const int ne3) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. @@ -313,7 +313,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { - GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); + GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); } break; } } diff --git a/llama/ggml-cuda/fattn-tile-f16.cuh b/llama/ggml-cuda/fattn-tile-f16.cuh index 1ad39a15..c48c863d 100644 --- a/llama/ggml-cuda/fattn-tile-f16.cuh +++ b/llama/ggml-cuda/fattn-tile-f16.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/fattn-tile-f32.cu b/llama/ggml-cuda/fattn-tile-f32.cu index 62a873ef..49c1ec56 100644 --- a/llama/ggml-cuda/fattn-tile-f32.cu +++ b/llama/ggml-cuda/fattn-tile-f32.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -310,7 +310,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { - GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); + GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); } break; } } diff --git a/llama/ggml-cuda/fattn-tile-f32.cuh b/llama/ggml-cuda/fattn-tile-f32.cuh index 74bf3dd9..87c48525 100644 --- a/llama/ggml-cuda/fattn-tile-f32.cuh +++ b/llama/ggml-cuda/fattn-tile-f32.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/fattn-vec-f16.cuh b/llama/ggml-cuda/fattn-vec-f16.cuh index a702edd2..496535c1 100644 --- a/llama/ggml-cuda/fattn-vec-f16.cuh +++ b/llama/ggml-cuda/fattn-vec-f16.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -66,7 +66,7 @@ static __global__ void flash_attn_vec_ext_f16( const int ne1, const int ne2, const int ne3) { -#if FP16_AVAILABLE +#ifdef FP16_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); diff --git a/llama/ggml-cuda/fattn-vec-f32.cuh b/llama/ggml-cuda/fattn-vec-f32.cuh index d2bd51f3..1517ac72 100644 --- a/llama/ggml-cuda/fattn-vec-f32.cuh +++ b/llama/ggml-cuda/fattn-vec-f32.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -175,7 +175,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); Q_f2[j][i0/WARP_SIZE].x *= scale; Q_f2[j][i0/WARP_SIZE].y *= scale; } diff --git a/llama/ggml-cuda/fattn-wmma-f16.cuh b/llama/ggml-cuda/fattn-wmma-f16.cuh index ada2a966..ce74f71d 100644 --- a/llama/ggml-cuda/fattn-wmma-f16.cuh +++ b/llama/ggml-cuda/fattn-wmma-f16.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -27,9 +27,9 @@ #include "common.cuh" #include "fattn-common.cuh" -#if FP16_MMA_AVAILABLE +#ifdef FP16_MMA_AVAILABLE #include -#endif +#endif // FP16_MMA_AVAILABLE // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template @@ -71,7 +71,7 @@ static __global__ void flash_attn_ext_f16( const int ne1, const int ne2, const int ne3) { -#if FP16_MMA_AVAILABLE +#ifdef FP16_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. diff --git a/llama/ggml-cuda/fattn.cu b/llama/ggml-cuda/fattn.cu index 9dfb824d..511e19d4 100644 --- a/llama/ggml-cuda/fattn.cu +++ b/llama/ggml-cuda/fattn.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -64,7 +64,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } } else { @@ -89,7 +89,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); // break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } } @@ -112,7 +112,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } return; @@ -140,7 +140,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } return; @@ -167,7 +167,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } } diff --git a/llama/ggml-cuda/fattn.cuh b/llama/ggml-cuda/fattn.cuh index abb65afc..e04eefbc 100644 --- a/llama/ggml-cuda/fattn.cuh +++ b/llama/ggml-cuda/fattn.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/getrows.cu b/llama/ggml-cuda/getrows.cu index d5276dc6..87b09d8b 100644 --- a/llama/ggml-cuda/getrows.cu +++ b/llama/ggml-cuda/getrows.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -197,8 +197,7 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { break; default: // TODO: k-quants - fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); - GGML_ASSERT(false); + GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); break; } } diff --git a/llama/ggml-cuda/getrows.cuh b/llama/ggml-cuda/getrows.cuh index 07e92cc9..0700d3a6 100644 --- a/llama/ggml-cuda/getrows.cuh +++ b/llama/ggml-cuda/getrows.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/im2col.cu b/llama/ggml-cuda/im2col.cu index 81dd9733..574e641b 100644 --- a/llama/ggml-cuda/im2col.cu +++ b/llama/ggml-cuda/im2col.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/im2col.cuh b/llama/ggml-cuda/im2col.cuh index 1e9aa07f..ca3d91f0 100644 --- a/llama/ggml-cuda/im2col.cuh +++ b/llama/ggml-cuda/im2col.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/mma.cuh b/llama/ggml-cuda/mma.cuh new file mode 100644 index 00000000..2e7fff79 --- /dev/null +++ b/llama/ggml-cuda/mma.cuh @@ -0,0 +1,247 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "common.cuh" + +struct mma_int_A_I16K4 { + static constexpr int I = 16; + static constexpr int K = 4; + static constexpr int ne = 2; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l%2) * (I/2) + threadIdx.x / K; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_k(const int /* l */) { + const int ret = threadIdx.x % K; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < K); + return ret; + } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) + const int * xs = xs0 + (threadIdx.x%I)*stride; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(x[0]), "+r"(x[1]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_i(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } +}; + +struct mma_int_A_I16K8 { + static constexpr int I = 16; + static constexpr int K = 8; + static constexpr int ne = 4; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_k(const int l) { + const int ret = (l/2) * (K/2) + threadIdx.x % (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < K); + return ret; + } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) + const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); + asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_i(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } + + __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { + ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); + } +}; + +struct mma_int_B_J8K4 { + static constexpr int J = 8; + static constexpr int K = 4; + static constexpr int ne = 1; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_j(const int /* l */) { + const int ret = threadIdx.x / K; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + static __device__ __forceinline__ int get_k(const int /* l */) { + const int ret = threadIdx.x % K; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < K); + return ret; + } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster + const int * xs = xs0 + (threadIdx.x%J)*stride; + asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" + : "+r"(x[0]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } +}; + +struct mma_int_B_J8K8 { + static constexpr int J = 8; + static constexpr int K = 8; + static constexpr int ne = 2; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_j(const int /* l */) { + const int ret = threadIdx.x / (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + static __device__ __forceinline__ int get_k(const int l) { + const int ret = l * (K/2) + threadIdx.x % (K/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < K); + return ret; + } + + __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { +#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster + const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; + asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "+r"(x[0]), "+r"(x[1]) + : "l"(xs)); +#else +#pragma unroll + for (int l = 0; l < ne; ++l) { + x[l] = xs0[get_j(l)*stride + get_k(l)]; + } +#endif // defined(INT8_MMA_AVAILABLE) + } +}; + +struct mma_int_C_I16J8 { + static constexpr int I = 16; + static constexpr int J = 8; + static constexpr int ne = 4; + + int x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < I); + return ret; + } + + static __device__ __forceinline__ int get_j(const int l) { + const int ret = 2 * (threadIdx.x % (J/2)) + l%2; + GGML_CUDA_ASSUME(ret >= 0); + GGML_CUDA_ASSUME(ret < J); + return ret; + } + + __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[0]), "+r"(x[1]) + : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[1]), "r"(mma_B.x[0])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } + + __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { +#ifdef INT8_MMA_AVAILABLE +#if __CUDA_ARCH__ >= CC_AMPERE + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1])); +#else + // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[0]), "+r"(x[1]) + : "r"(mma_A.x[0]), "r"(mma_B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[1]), "r"(mma_B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[0]), "+r"(x[1]) + : "r"(mma_A.x[2]), "r"(mma_B.x[1])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(x[2]), "+r"(x[3]) + : "r"(mma_A.x[3]), "r"(mma_B.x[1])); +#endif // __CUDA_ARCH__ >= CC_AMPERE +#else + GGML_UNUSED(mma_A); + GGML_UNUSED(mma_B); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE + } +}; diff --git a/llama/ggml-cuda/mmq.cu b/llama/ggml-cuda/mmq.cu index 346a4c72..a5046bf1 100644 --- a/llama/ggml-cuda/mmq.cu +++ b/llama/ggml-cuda/mmq.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -37,6 +37,7 @@ void ggml_cuda_op_mul_mat_q( const int64_t nb01 = src0->nb[1]; const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; @@ -51,41 +52,65 @@ void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst}; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst}; switch (src0->type) { case GGML_TYPE_Q4_0: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q4_1: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q5_0: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q5_1: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q8_0: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q2_K: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q3_K: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q4_K: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q5_K: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); break; case GGML_TYPE_Q6_K: - mul_mat_q_case(args, stream); + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_q_case(ctx, args, stream); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } @@ -94,7 +119,13 @@ void ggml_cuda_op_mul_mat_q( GGML_UNUSED(src1_ddf_i); } -bool ggml_cuda_supports_mmq(enum ggml_type type) { +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { +#ifdef GGML_CUDA_FORCE_CUBLAS + return false; +#endif // GGML_CUDA_FORCE_CUBLAS + + bool mmq_supported; + switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -106,8 +137,40 @@ bool ggml_cuda_supports_mmq(enum ggml_type type) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: - return true; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + mmq_supported = true; + break; default: - return false; + mmq_supported = false; + break; } + + if (!mmq_supported) { + return false; + } + + if (int8_mma_available(cc)) { + return true; + } + + if (cc < MIN_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (cc < CC_OFFSET_AMD) { + return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; } diff --git a/llama/ggml-cuda/mmq.cuh b/llama/ggml-cuda/mmq.cuh index 4a04f4a2..ab18ee1f 100644 --- a/llama/ggml-cuda/mmq.cuh +++ b/llama/ggml-cuda/mmq.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -24,107 +24,247 @@ * SOFTWARE. */ +#pragma once + #include "common.cuh" #include "vecdotq.cuh" +#include "mma.cuh" #include #include -typedef void (*load_tiles_mmq_t)( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. +#define MMQ_ITER_K 256 +#define MMQ_NWARPS 8 + +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); + +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + +struct block_q8_1_mmq { + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. + union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each +}; +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); + +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ABORT("fatal error"); + break; + } +} struct tile_x_sizes { - int ql; + int qs; int dm; - int qh; int sc; }; -// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row +static constexpr int get_mmq_x_max_host(const int cc) { + return int8_mma_available(cc) ? 128 : +#ifdef GGML_CUDA_FORCE_MMQ + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; +#else + cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ +} static constexpr __device__ int get_mmq_x_max_device() { +#ifdef INT8_MMA_AVAILABLE + return 128; +#else // INT8_MMA_AVAILABLE + #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return 64; -#else + return 128; +#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + #if __CUDA_ARCH__ >= CC_VOLTA -#ifdef CUDA_USE_TENSOR_CORES - return MMQ_MAX_BATCH_SIZE; +#ifdef GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#else // GGML_CUDA_FORCE_MMQ + return 128; +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= CC_VOLTA + + return 64; +#endif // __CUDA_ARCH__ >= CC_VOLTA + +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#endif // INT8_MMA_AVAILABLE +} + +static constexpr int get_mmq_y_host(const int cc) { + return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64); +} + +static constexpr __device__ int get_mmq_y_device() { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA1) + return 64; #else return 128; -#endif // CUDA_USE_TENSOR_CORES -#else - return 64; -#endif // __CUDA_ARCH__ >= CC_VOLTA -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -} - -// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -static constexpr __device__ int get_mmq_y_device(int mmq_x) { - return mmq_x >= 32 ? 128 : 64; -} +#endif // defined RDNA1 #else #if __CUDA_ARCH__ >= CC_VOLTA -static constexpr __device__ int get_mmq_y_device(int mmq_x) { - return mmq_x >= 32 ? 128 : 64; -} + return 128; #else -static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) { return 64; -} #endif // __CUDA_ARCH__ >= CC_VOLTA #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - -#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0} -#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0} -#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0} -#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0} -#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0} -#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4} -#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} - -#define GET_TILE_X_SIZES_BODY \ - return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \ - type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \ - type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \ - type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \ - type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \ - type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \ - type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \ - type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \ - type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \ - type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \ - tile_x_sizes{0, 0, 0, 0} - -static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) { - GET_TILE_X_SIZES_BODY; } -template -static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) { - GET_TILE_X_SIZES_BODY; +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} + +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : + type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : + type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : + type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : + type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : + type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : + type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : + type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : + tile_x_sizes{0, 0, 0}; } +#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : + type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : + type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : + type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : + type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : + type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : + 0; +} + +#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8; +} + +#ifdef INT8_MMA_AVAILABLE +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { + return 8; +} +#endif // INT8_MMA_AVAILABLE + // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_0; const int kqsx = threadIdx.x % QI4_0; - float * x_dmf = (float *) x_dm; - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -134,8 +274,14 @@ template static __device__ __forceinlin } const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; @@ -151,47 +297,65 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template -static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); - int u[2*VDR_Q4_0_Q8_1_MMQ]; + int u[2*VDR_Q4_0_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u, + x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI4_1; const int kqsx = threadIdx.x % QI4_1; @@ -205,8 +369,14 @@ template static __device__ __forceinlin } const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; @@ -222,46 +392,65 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } } template -static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); - int u[2*VDR_Q4_1_Q8_1_MMQ]; + int u[2*VDR_Q4_1_Q8_1_MMQ]; #pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u, + x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); } } } template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_0; const int kqsx = threadIdx.x % QI5_0; @@ -276,8 +465,8 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; - const int ql = get_int_from_uint8(bxi->qs, kqsx); - const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -286,8 +475,6 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 @@ -295,12 +482,17 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; const int kbxd = threadIdx.x % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { @@ -312,49 +504,25 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } -template -static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; - } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } - } -} - - template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI5_1; const int kqsx = threadIdx.x % QI5_1; @@ -369,8 +537,8 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); + const int ql = get_int_b4(bxi->qs, kqsx); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -378,15 +546,19 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; - int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; @@ -402,51 +574,28 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; - } -} - -template -static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { - - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; - } - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); - } +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // INT8_MMA_AVAILABLE } } template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE const int kbx = threadIdx.x / QI8_0; const int kqsx = threadIdx.x % QI8_0; - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -458,15 +607,21 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); +#else + x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); +#endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) { + int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -474,249 +629,729 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } } template -static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]); + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE], + x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } } } } +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*WARP_SIZE; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + mma_A A[ntx][WARP_SIZE/QI8_0]; + float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + + A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + mma_B B; + float dB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + typedef mma_int_A_I16K8 mma_A; + typedef mma_int_B_J8K8 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + mma_A A[ntx][WARP_SIZE/QI8_1]; + float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + mma_B B; + float2 dsB[mma_C::ne/2]; + + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C; + C.mma_K8(A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*WARP_SIZE + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][8]; + float dA[ntx][mma_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + mma_B B[2]; + float dB[mma_C::ne/2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE - const int kbx = threadIdx.x / QI2_K; const int kqsx = threadIdx.x % QI2_K; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K; if (need_check) { i = min(i, i_max); } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx; + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { - int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int l = 0; l < QR2_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; - if (need_check) { - i = min(i, i_max); + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; +#endif // INT8_MMA_AVAILABLE } - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd; + const int sc_m = bxi->scales[kqsx]; +#ifdef FAST_FP16_AVAILABLE + const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); +#else + const float2 bxi_dmf = __half22float2(bxi->dm); + const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); +#endif // FAST_FP16_AVAILABLE - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4); - - x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4)); +#ifdef INT8_MMA_AVAILABLE + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; +#else + x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; +#endif // INT8_MMA_AVAILABLE } } template -static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + float2 y_df[mmq_x/nwarps]; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const int kbx = k0 / QI2_K; - const int ky = (k0 % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; - - int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); - const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } #pragma unroll - for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { - v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (k01 < WARP_SIZE/2) { + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } else { + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } } - - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - - const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); } } } +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][8]; + float dA[ntx][mma_C::ne/2][8]; + float mA[ntx][mma_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + const int k0 = k00 + k01; + + ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { + const int k0 = k00 + k01; + + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float2 dB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + mma_B B[2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + + mma_C Cm[2]; + if (k01 >= WARP_SIZE * 3/4) { + mma_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + Cm[0].mma_K4(A1, B[0]); + Cm[1].mma_K4(A1, B[1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C Cd[2]; + + Cd[0].mma_K4(A[n][k01/4 + 0], B[0]); + Cd[1].mma_K4(A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= WARP_SIZE * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + } + } + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { + float2 sB[mma_C::ne/2]; + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + } + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // INT8_MMA_AVAILABLE - const int kbx = threadIdx.x / QI3_K; const int kqsx = threadIdx.x % QI3_K; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) { + int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K; if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx; + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = threadIdx.x % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int l = 0; l < QR3_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; - if (need_check) { - i = min(i, i_max); + const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; + const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; + + const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; +#else + x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; +#endif // INT8_MMA_AVAILABLE } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; - - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { + int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2); + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2)); - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4); - - const int ksc = threadIdx.x % (QI3_K/4); + const int ksc = threadIdx.x % (WARP_SIZE/8); const int ksc_low = ksc % (QI3_K/8); const int shift_low = 4 * (ksc / (QI3_K/8)); - const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; const int ksc_high = QI3_K/8; const int shift_high = 2 * ksc; - const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc; +#ifdef INT8_MMA_AVAILABLE + const int8_t * sc8 = (const int8_t *) ≻ + const float d = bxi->d; + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; + } +#else + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; +#endif // INT8_MMA_AVAILABLE } + +#ifndef INT8_MMA_AVAILABLE +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { + int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + x_df[i] = bxi->d; + } +#endif // INT8_MMA_AVAILABLE } template -static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const int kbx = k0 / QI3_K; - const int ky = (k0 % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; + const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4; - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - - int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); - const int shift = 2 * ((ky % 32) / 8); - const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); - const int vlh = (vh << 2) & 0x04040404; - - v[l] = __vsubss4(vll, vlh); + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } - - const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); } } } -template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits +} - const int kbx = 0; // threadIdx.x / QI4_K - const int kqsx = threadIdx.x; // threadIdx.x % QI4_K +template static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -726,25 +1361,59 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int qs0 = get_int_b4(bxi->qs, threadIdx.x); - x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; +#endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 +#ifdef INT8_MMA_AVAILABLE #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { + int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % (WARP_SIZE/16); + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + +#else + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) { + int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; } #pragma unroll @@ -760,46 +1429,58 @@ template static __device__ __forceinlin const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } +#endif // INT8_MMA_AVAILABLE } template -static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); + const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16); - const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { - const int kbx = 0; // threadIdx.x / QI5_K - const int kqsx = threadIdx.x; // threadIdx.x % QI5_K +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -809,94 +1490,139 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx; - const int ky = QR5_K*kqsx; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const int ky = QR5_K*threadIdx.x; - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql = get_int_b4(bxi->qs, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010; const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); + const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; +#else + x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; +#endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 +#ifdef INT8_MMA_AVAILABLE #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { + int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; - } + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % (WARP_SIZE/16); + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + +#else + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) { + int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { + int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } +#endif // INT8_MMA_AVAILABLE } template -static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16); - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0; - const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } } } template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { - GGML_UNUSED(x_qh); + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { - const int kbx = 0; // threadIdx.x / QI6_K - const int kqsx = threadIdx.x; // threadIdx.x % QI6_K +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); + int * x_sc = (int *) (x_df + WARP_SIZE/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { @@ -906,27 +1632,30 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx; - const int ky = QR6_K*kqsx; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql = get_int_b2(bxi->ql, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4)); + const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030; - const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); + const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; + const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // INT8_MMA_AVAILABLE } const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { @@ -938,7 +1667,11 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; +#else + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; +#endif // INT8_MMA_AVAILABLE } #pragma unroll @@ -951,34 +1684,686 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); +#ifdef INT8_MMA_AVAILABLE + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); +#else + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); +#endif // INT8_MMA_AVAILABLE } } template -static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { - GGML_UNUSED(x_qh); + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][8]; + int scA[ntx][mma_C::ne/2][8]; + float dA[ntx][mma_C::ne/2]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + const int k0 = k00 + k01; + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { + float tmp[ntx][mma_C::ne] = {{0.0f}}; + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + mma_B B[2]; + float dB[mma_C::ne/2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; + } + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } + } + } +#else + GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE +} + +template static __device__ __forceinline__ void load_tiles_iq4_nl( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = threadIdx.x / QI4_NL; + const int kqsx = threadIdx.x % QI4_NL; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b2(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4); + const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) { + int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#else + x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_XXS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_XS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_S/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI3_XXS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI3_S/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % QI1_S; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // INT8_MMA_AVAILABLE + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + +#ifdef INT8_MMA_AVAILABLE + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq4_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kbx = 0; // threadIdx.x / QI4_XS + const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4); + const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; + x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; +#endif // INT8_MMA_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const float d = __half2float(bxi->d); + + const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) + | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); + +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#else + x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // INT8_MMA_AVAILABLE + } +} + +template +static __device__ __forceinline__ void mmq_write_back_dp4a( + const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j > j_max) { + return; + } + #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; + if (need_check && i > i_max) { + continue; + } - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); + dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + } + } +} - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0; - const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); +template +static __device__ __forceinline__ void mmq_write_back_mma( + const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { + + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I); +#ifdef INT8_MMA_AVAILABLE + static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l); + + if (j > j_max) { + continue; + } + + const int i = i0 + n*mma_C::I + mma_C::get_i(l); + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l]; + } } } } @@ -990,84 +2375,225 @@ struct mmq_type_traits; template struct mmq_type_traits { - static constexpr bool need_sum = true; - static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = true; - static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = false; - static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = true; - static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = false; - static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = false; - static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = false; - static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = true; - static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = true; - static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; template struct mmq_type_traits { - static constexpr bool need_sum = false; - static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; - static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +static __device__ void mul_mat_q_process_tile( + const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0, + const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + + extern __shared__ char data_mul_mat_q[]; + int * tile_y = (int *) data_mul_mat_q; + int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); + +#ifdef INT8_MMA_AVAILABLE + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#else + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // INT8_MMA_AVAILABLE + + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + + float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + + const int tile_x_max_i = ne01 - it*mmq_y - 1; + const int tile_y_max_j = ne11 - jt*mmq_x - 1; + + const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); + + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { + load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + + { + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); +#pragma unroll + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { + int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, WARP_SIZE); + + __syncthreads(); + } + + if (fixup) { + write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + } else { + write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j); + } +} + + +// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 + template #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA3) || defined(RDNA2) @@ -1077,241 +2603,334 @@ template #if __CUDA_ARCH__ >= CC_VOLTA __launch_bounds__(WARP_SIZE*nwarps, 1) #else - __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2) + __launch_bounds__(WARP_SIZE*nwarps, 2) #endif // __CUDA_ARCH__ >= CC_VOLTA #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( - const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, - const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) { + const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { // Skip unused template specializations for faster compilation: - if (mmq_x > get_mmq_x_max_device()) { + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { NO_DEVICE_CODE; return; } - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qr = ggml_cuda_type_traits::qr; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int mmq_y = get_mmq_y_device(mmq_x); - constexpr bool need_sum = mmq_type_traits::need_sum; - constexpr int vdr = mmq_type_traits::vdr; - constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int mmq_y = get_mmq_y_device(); - constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); + // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: +#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA + { + constexpr bool fixup = false; + mul_mat_q_process_tile + (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, + blockIdx.x, blockIdx.y, 0, ne00/qk); + return; + } +#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA - extern __shared__ char data_mul_mat_q[]; - int * tile_x_ql = (int *) data_mul_mat_q; - half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql); - int * tile_x_qh = (int *) (tile_x_dm + txs.dm); - int * tile_x_sc = (int *) (tile_x_qh + txs.qh); - int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] - half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; + const int64_t blocks_per_ne00 = ne00 / qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; - const block_q8_1 * y = (const block_q8_1 *) yc; + const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y - const int blocks_per_row_x = ne00 / qk; - const int blocks_per_col_y = ne10 / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x; - const int & ne1 = ne11; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; - const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1; + // kb0 == k index when doing the matrix multiplication for an output tile. + int kb0_start = kbc % blocks_per_ne00; + int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { + const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile. + const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile. - float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; + constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + mul_mat_q_process_tile + (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, + it, jt, kb0_start, kb0_stop); - for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; - load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00); + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + const int jt = kbc / (blocks_per_ne00*nty); + const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + + constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks. + mul_mat_q_process_tile + (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, + it, jt, kb0_start, kb0_stop); +} + + +template +static __global__ void mul_mat_q_stream_k_fixup( + float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) { + + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int blocks_per_iter = MMQ_ITER_K / qk; + const int64_t blocks_per_ne00 = ne00 / qk; + + float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + + const int ntx = (ne11 + mmq_x - 1) / mmq_x; + const int nty = (ne01 + mmq_y - 1) / mmq_y; + + bool any_fixup = false; + + const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x); + const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x); + + int64_t kbc_0; + int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq; + + for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { + kbc_0 = kbc_stop_0; + kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; + + const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter; + const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter; + + // Skip fixup tile if the MMQ CUDA block never wrote anything to it: + if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) { + continue; + } + + const int jt = kbc_stop / (blocks_per_ne00*nty); + const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + + // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block: + if (it != blockIdx.x || jt != blockIdx.y) { + continue; + } + + any_fixup = true; #pragma unroll - for (int kr = 0; kr < qr; ++kr) { - const int kqs = kr*WARP_SIZE + threadIdx.x; - const int kbxd = kqs / QI8_1; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_x; i0 += nwarps) { - const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; - const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd]; - - const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } - -#pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); - const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = __low2float(*dsi_src); - } - } - - __syncthreads(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { - vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0); - } - - __syncthreads(); } } + if (!any_fixup) { + return; + } + + dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y; + + const int i_max = ne01 - blockIdx.x*mmq_y - 1; + const int j_max = ne11 - blockIdx.y*mmq_x - 1; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { - const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; + const int j = j0 + threadIdx.y; - if (j >= ne1) { + if (j > j_max) { return; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { - const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; + const int i = i0 + threadIdx.x; - if (need_check && i >= ne0) { + if (need_check && i > i_max) { continue; } - dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; } } } struct mmq_args { const char * x; const char * y; float * dst; - int64_t ne00; int64_t ne01; int64_t stride00; - int64_t ne10; int64_t ne11; + int64_t ne00; int64_t ne01; int64_t stride01; + int64_t ne10; int64_t ne11; int64_t stride11; int64_t ne0; }; -template -static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { +template +static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); + return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); +} + +template +static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - const int mmq_y = get_mmq_y_host(cc, mmq_x); + const int nsm = ggml_cuda_info().devices[id].nsm; + const int mmq_y = get_mmq_y_host(cc); - const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y; - const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); - const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); - const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int); - const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); - const int shmem = shmem_x + shmem_y; + const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shmem_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); shmem_limit_raised[id] = true; } #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + const int nty = (args.ne01 + mmq_y - 1) / mmq_y; + const int ntx = (args.ne11 + mmq_x - 1) / mmq_x; + const dim3 block_nums_xy_tiling(nty, ntx, 1); + + const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD; + if (!use_stream_k) { + if (args.ne01 % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q<<>> + (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + } else { + constexpr bool need_check = true; + mul_mat_q<<>> + (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + } + return; + } + + const dim3 block_nums_mmq(nsm, 1, 1); + + ggml_cuda_pool & pool = ctx.pool(id); + ggml_cuda_pool_alloc tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y); + if (args.ne01 % mmq_y == 0) { - const bool need_check = false; - mul_mat_q<<>> - (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + constexpr bool need_check = false; + + mul_mat_q<<>> + (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + + mul_mat_q_stream_k_fixup<<>> + (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); } else { - const bool need_check = true; - mul_mat_q<<>> - (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + constexpr bool need_check = true; + + mul_mat_q<<>> + (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + + mul_mat_q_stream_k_fixup<<>> + (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); } } template -void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int nsm = ggml_cuda_info().devices[id].nsm; - const int cc = ggml_cuda_info().devices[id].cc; +void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + const int cc = ggml_cuda_info().devices[id].cc; + const int smpbo = ggml_cuda_info().devices[id].smpbo; const int mmq_x_max = get_mmq_x_max_host(cc); - const int mmq_y = get_mmq_y_host(cc, mmq_x_max); + const int mmq_y = get_mmq_y_host(cc); const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; + const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD; int mmq_x_best = 0; - int nwaves_best = INT_MAX; + int nparts_best = INT_MAX; - for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) { - const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x; - const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm; + for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (nwaves < nwaves_best) { + if (mmq_x % granularity != 0 || mmq_get_shmem(mmq_x, mmq_y, cc) > smpbo) { + continue; + } + + const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x; + const int nwaves_xy_tiling = ntiles_x*block_num_y; + const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling; + + if (nparts < nparts_best) { mmq_x_best = mmq_x; - nwaves_best = nwaves; + nparts_best = nparts; } } switch (mmq_x_best) { case 8: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 16: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 24: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 32: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 40: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 48: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 56: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 64: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 72: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 80: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 88: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 96: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 104: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 112: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 120: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; case 128: - launch_mul_mat_q(args, stream); + launch_mul_mat_q(ctx, args, stream); break; default: - GGML_ASSERT(false); + fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); + GGML_ABORT("fatal error"); break; } } #define DECL_MMQ_CASE(type) \ - template void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) \ + template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \ extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); @@ -1323,6 +2942,14 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); // ------------------------------------------------------------------------------------------------------------------------- @@ -1332,4 +2959,4 @@ void ggml_cuda_op_mul_mat_q( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_supports_mmq(enum ggml_type type); +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); diff --git a/llama/ggml-cuda/mmvq.cu b/llama/ggml-cuda/mmvq.cu index ca0c824e..f693109a 100644 --- a/llama/ggml-cuda/mmvq.cu +++ b/llama/ggml-cuda/mmvq.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -54,16 +54,22 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) static constexpr __device__ int get_vdr_mmvq(ggml_type type) { return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : - type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ : + type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : + type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : + type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : + type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : + type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : + type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : + type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : + type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : + type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : + type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : + type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : + type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : + type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ : + type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : + type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : + type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : 1; } @@ -143,7 +149,7 @@ static __global__ void mul_mat_vec_q( tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - if (threadIdx.x < rows_per_cuda_block) { + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; } } @@ -182,7 +188,7 @@ static void mul_mat_vec_q_cuda( rows_per_cuda_block = 2; break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } } @@ -216,7 +222,7 @@ static void mul_mat_vec_q_cuda( mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } } @@ -433,7 +439,7 @@ void ggml_cuda_op_mul_mat_vec_q( mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); break; } diff --git a/llama/ggml-cuda/mmvq.cuh b/llama/ggml-cuda/mmvq.cuh index e2e138bd..c76123b1 100644 --- a/llama/ggml-cuda/mmvq.cuh +++ b/llama/ggml-cuda/mmvq.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -26,6 +26,8 @@ #include "common.cuh" +#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. + void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/llama/ggml-cuda/norm.cu b/llama/ggml-cuda/norm.cu index 1770b32d..f27c597f 100644 --- a/llama/ggml-cuda/norm.cu +++ b/llama/ggml-cuda/norm.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/norm.cuh b/llama/ggml-cuda/norm.cuh index cbb30be9..cd20016a 100644 --- a/llama/ggml-cuda/norm.cuh +++ b/llama/ggml-cuda/norm.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/pad.cu b/llama/ggml-cuda/pad.cu index 7bbb79a2..38abb23e 100644 --- a/llama/ggml-cuda/pad.cu +++ b/llama/ggml-cuda/pad.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/pad.cuh b/llama/ggml-cuda/pad.cuh index b0ab5d3f..33b5f1b6 100644 --- a/llama/ggml-cuda/pad.cuh +++ b/llama/ggml-cuda/pad.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/pool2d.cu b/llama/ggml-cuda/pool2d.cu index acb837ca..f14bdd35 100644 --- a/llama/ggml-cuda/pool2d.cu +++ b/llama/ggml-cuda/pool2d.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/pool2d.cuh b/llama/ggml-cuda/pool2d.cuh index 2773cad9..3a680462 100644 --- a/llama/ggml-cuda/pool2d.cuh +++ b/llama/ggml-cuda/pool2d.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/quantize.cu b/llama/ggml-cuda/quantize.cu index ad7e2ef3..6c5b6f9f 100644 --- a/llama/ggml-cuda/quantize.cu +++ b/llama/ggml-cuda/quantize.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -25,24 +25,25 @@ */ #include "quantize.cuh" +#include -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx_padded) { - const int64_t ix = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { + const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; - if (ix >= kx_padded) { + if (ix0 >= kx0_padded) { return; } - const int64_t iy = (int64_t)blockDim.y*blockIdx.y + threadIdx.y; + const int64_t ix1 = blockIdx.y; - const int64_t i_padded = (int64_t)iy*kx_padded + ix; + const int64_t i_padded = ix1*kx0_padded + ix0; block_q8_1 * y = (block_q8_1 *) vy; const int64_t ib = i_padded / QK8_1; // block index const int64_t iqs = i_padded % QK8_1; // quant index - const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -62,10 +63,133 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream) { - const int64_t block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, ky, 1); - const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx, kx_padded); +template +static __global__ void quantize_mmq_q8_1( + const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + + if (ix0 >= kx0_padded) { + return; + } + + const float4 * x4 = (const float4 *) x; + + const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel + const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); + + // Exchange max. abs. value between vals_per_scale/4 threads. +#pragma unroll + for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + } + + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; + + // Exchange calculate sum across vals_per_sum/4 threads. +#pragma unroll + for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); + } + } + + const float d_inv = 127.0f / amax; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + y[ib].d2s6[iqs/64] = d; + + return; + } + + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); + } else { + y[ib].d4[iqs/32] = d; + } } +void quantize_row_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, + const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % QK8_1 == 0); + + const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, kx1*channels, 1); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<>>(x, vy, kx0, kx0_padded); + + GGML_UNUSED(type_x); +} + +void quantize_mmq_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, + const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + + GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); + + const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(block_num_x, kx1, channels); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_x)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/llama/ggml-cuda/quantize.cuh b/llama/ggml-cuda/quantize.cuh index 82656f24..f533e30e 100644 --- a/llama/ggml-cuda/quantize.cuh +++ b/llama/ggml-cuda/quantize.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -24,8 +24,27 @@ * SOFTWARE. */ +#pragma once + #include "common.cuh" +#include "mmq.cuh" -#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#include -void quantize_row_q8_1_cuda(const float * x, void * vy, const int64_t kx, const int64_t ky, const int64_t kx_padded, cudaStream_t stream); +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128 + +static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access."); +static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); + +typedef void (*quantize_cuda_t)( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + +void quantize_row_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); + +void quantize_mmq_q8_1_cuda( + const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, + const ggml_type type_x, cudaStream_t stream); diff --git a/llama/ggml-cuda/rope.cu b/llama/ggml-cuda/rope.cu index 6e7327a3..5046697c 100644 --- a/llama/ggml-cuda/rope.cu +++ b/llama/ggml-cuda/rope.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -277,7 +277,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { attn_factor, corr_dims, freq_factors, stream ); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32) { @@ -291,7 +291,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { attn_factor, corr_dims, freq_factors, stream ); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } } diff --git a/llama/ggml-cuda/rope.cuh b/llama/ggml-cuda/rope.cuh index d06466b5..aa34b1df 100644 --- a/llama/ggml-cuda/rope.cuh +++ b/llama/ggml-cuda/rope.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/scale.cu b/llama/ggml-cuda/scale.cu index 84eb76fe..e2d849e0 100644 --- a/llama/ggml-cuda/scale.cu +++ b/llama/ggml-cuda/scale.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/scale.cuh b/llama/ggml-cuda/scale.cuh index 69e29111..4c0dc83f 100644 --- a/llama/ggml-cuda/scale.cuh +++ b/llama/ggml-cuda/scale.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/softmax.cu b/llama/ggml-cuda/softmax.cu index d3fc1044..db94d7de 100644 --- a/llama/ggml-cuda/softmax.cu +++ b/llama/ggml-cuda/softmax.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -156,6 +156,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // FIXME: this limit could be raised by ~2-4x on Ampere or newer if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) { switch (ncols_x) { case 32: diff --git a/llama/ggml-cuda/softmax.cuh b/llama/ggml-cuda/softmax.cuh index 60df821a..ac4e2914 100644 --- a/llama/ggml-cuda/softmax.cuh +++ b/llama/ggml-cuda/softmax.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/sumrows.cu b/llama/ggml-cuda/sumrows.cu index 241bfb71..a6b8f720 100644 --- a/llama/ggml-cuda/sumrows.cu +++ b/llama/ggml-cuda/sumrows.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/sumrows.cuh b/llama/ggml-cuda/sumrows.cuh index 69eb3994..9b8c9cd6 100644 --- a/llama/ggml-cuda/sumrows.cuh +++ b/llama/ggml-cuda/sumrows.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu index cfd7e8a7..05196989 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu index 817efe2d..fd02735b 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu index 277e8fc2..5fdcd8e4 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu index 47722adf..e032d0b3 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu index af548e34..6c89d944 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu index 01a268ae..b5326ec7 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-f16-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu index cbde6106..c654b9d9 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu index 7d8734c5..3eeed729 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu index e7cc120f..4c8b8e7e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu index a6b29bea..ed93bda8 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu index 115a8f69..dd7a6ed9 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu index b6e9b64e..f13cbabb 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_0-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu index 01ae3deb..c50660d2 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu index bb761ecf..a32ba4e0 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu index bd4f2035..117c686d 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu index 8b797364..83b169e4 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu index 1324b33a..44883202 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu index 432ce759..ea964906 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q4_1-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu index 45a99d8d..488ff9a6 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu index a6549a73..1a0449a2 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu index 8abaa17a..b1a2723e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu index 7ae05c52..74f18b63 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu index cd6c7980..d6350bec 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu index 362b035e..5ecc0c48 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_0-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu index fffe9344..641d6a04 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu index 0ad3077b..7615d691 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu index 03a1a375..c5755ff3 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu index cd27a825..375f370e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu index 2adc1a0b..555eba19 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu index 118bfa56..29982a4c 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q5_1-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu index 7da915b9..cd8b538a 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu index 66d49c32..a102886a 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu index a17736df..700c84a8 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu index 41d8c6d0..acf305d4 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu index 4800bf50..c29b8262 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu index 63dc37ea..5b96efad 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs128-q8_0-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu index de960397..2c4a76ad 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu index c5333b6c..6a4a424c 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu index 74d561bc..949cba5f 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu index d1435f15..7e360e14 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu index 34ed5b18..afb4d80e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu index cffcaad2..aa39aefb 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu index 319529d9..78bc0019 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f16-instance-hs64-f16-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu index 02f05ac4..35f772f1 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu index f449f7d1..6afb111e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu index c77e21e8..03a69b8c 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu index fff43be6..59ad9cd8 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu index c0fff458..cd84c81e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu index ec4b3fb7..6ef8b30f 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-f16-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu index 9117b831..1cdf9601 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu index ce9a8515..092b6757 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu index ceceb334..5fd20888 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu index 6ea92d98..7fd85f46 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu index 6342bdd9..39d5f402 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu index 5a621d3a..5dd34807 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_0-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu index 5637c3ac..8fa2a892 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu index 337ebc56..74a935f6 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu index d2316647..9c336952 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu index c57f31bb..c1691913 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu index 5693bbd3..ddb6f5c4 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu index 7278d4b7..460e0501 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q4_1-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu index ff0ea22e..788346ed 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu index f5d9254c..dfb2a12d 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu index 3501943d..4b9848d5 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu index cf0b7d78..141a3c0d 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu index 157abfd3..5e9736b8 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu index 96dbf49c..6027c480 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_0-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu index c9a7d79d..d766d427 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu index 948392db..3af17ada 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu index 28db7404..28ce6f86 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu index 12b2795a..5dc4609e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu index 5f428d13..bd97d45e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu index d4774fc5..7d0b363e 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q5_1-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu index 39180299..92ee4c0f 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu index be1d5bb5..51fde074 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu index 63c564bd..235e3872 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu index 244ff284..dc3715d7 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu index 7d829023..a5b4241f 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu index f8e90769..9a2fe54a 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs128-q8_0-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu index 7d935c2c..5d8153e2 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu index de917434..73102eaf 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-f16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu index 1ec95ad5..2f1a60bc 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu index 6463247e..5c2395be 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu index 423a6432..e038d84f 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu index 9af97fe1..832789fa 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu index 064bb392..c5b27e37 100644 --- a/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu +++ b/llama/ggml-cuda/template-instances/fattn-vec-f32-instance-hs64-f16-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu index 076d4233..2f34c8fb 100644 --- a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu +++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu index 5029593a..f443658e 100644 --- a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu +++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqfloat-cpb32.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu index b5d008eb..3e1304de 100644 --- a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu +++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb16.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu index fb148383..d7c6d597 100644 --- a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu +++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb32.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu index ee8af8f0..6bc3dc3f 100644 --- a/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu +++ b/llama/ggml-cuda/template-instances/fattn-wmma-f16-instance-kqhalf-cpb8.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq1_s.cu new file mode 100644 index 00000000..7b484e65 --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq1_s.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq2_s.cu new file mode 100644 index 00000000..445791db --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq2_s.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu new file mode 100644 index 00000000..4f7eb4ba --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu new file mode 100644 index 00000000..bb1a3adb --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq3_s.cu new file mode 100644 index 00000000..01affe46 --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq3_s.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu new file mode 100644 index 00000000..badd19cf --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu new file mode 100644 index 00000000..e79360f9 --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu b/llama/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu new file mode 100644 index 00000000..fa75948f --- /dev/null +++ b/llama/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu @@ -0,0 +1,31 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu index ed3c331a..cb3d2b14 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q2_k.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu index 1dbfe57d..3afd2877 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q3_k.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu index 221485a5..e6fcb3d5 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q4_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu index be424bd2..e8c23dae 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q4_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu index aeb1071e..1b106850 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q4_k.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu index 278e3003..d17d2636 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q5_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu index e8f1d0e0..e0f6b4ad 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q5_1.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu index ee9d03ff..cc50ae8d 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q5_k.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu index 4df022f6..66cd6c91 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q6_k.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu index 7034a4bf..ac2f5322 100644 --- a/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu +++ b/llama/ggml-cuda/template-instances/mmq-instance-q8_0.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/tsembd.cu b/llama/ggml-cuda/tsembd.cu index 452dfd63..3feed02b 100644 --- a/llama/ggml-cuda/tsembd.cu +++ b/llama/ggml-cuda/tsembd.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/tsembd.cuh b/llama/ggml-cuda/tsembd.cuh index 6d423c1f..cbfd942e 100644 --- a/llama/ggml-cuda/tsembd.cuh +++ b/llama/ggml-cuda/tsembd.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/unary.cu b/llama/ggml-cuda/unary.cu index 38e5f1b4..db9fa38d 100644 --- a/llama/ggml-cuda/unary.cu +++ b/llama/ggml-cuda/unary.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -118,6 +118,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) { dst[i] = x[i] * x[i]; } +static __global__ void sqrt_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + dst[i] = sqrtf(x[i]); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -168,12 +177,19 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t sqr_f32<<>>(x, dst, k); } +static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE; + sqrt_f32<<>>(x, dst, k); +} + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -186,6 +202,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -198,6 +216,8 @@ void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -210,6 +230,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -222,6 +244,8 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -234,6 +258,8 @@ void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -246,6 +272,8 @@ void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -258,6 +286,8 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -270,6 +300,8 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -285,8 +317,24 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} diff --git a/llama/ggml-cuda/unary.cuh b/llama/ggml-cuda/unary.cuh index 75fd718b..3d4a675b 100644 --- a/llama/ggml-cuda/unary.cuh +++ b/llama/ggml-cuda/unary.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -34,6 +34,7 @@ #define CUDA_HARDSIGMOID_BLOCK_SIZE 256 #define CUDA_HARDSWISH_BLOCK_SIZE 256 #define CUDA_SQR_BLOCK_SIZE 256 +#define CUDA_SQRT_BLOCK_SIZE 256 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -54,3 +55,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama/ggml-cuda/upscale.cu b/llama/ggml-cuda/upscale.cu index e261bc17..4e5e614f 100644 --- a/llama/ggml-cuda/upscale.cu +++ b/llama/ggml-cuda/upscale.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/upscale.cuh b/llama/ggml-cuda/upscale.cuh index 06d64e6f..e3951934 100644 --- a/llama/ggml-cuda/upscale.cuh +++ b/llama/ggml-cuda/upscale.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/ggml-cuda/vecdotq.cuh b/llama/ggml-cuda/vecdotq.cuh index bdac93c8..97360639 100644 --- a/llama/ggml-cuda/vecdotq.cuh +++ b/llama/ggml-cuda/vecdotq.cuh @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -25,36 +25,21 @@ */ #include "common.cuh" +#include -static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) { - const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment +static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; return x32; } -static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) { - const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment - - int x32 = 0; - x32 |= x16[0] << 0; - x32 |= x16[1] << 16; - - return x32; +static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment } -static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) { - return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment -} - -static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) { - return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment -} - - // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -64,7 +49,6 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( const int * v, const int * u, const float & d4, const half2 & ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll @@ -73,17 +57,14 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); } const float2 ds8f = __half22float2(ds8); // second part effectively subtracts 8 from each quant value return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q4_1_Q8_1_MMVQ 2 @@ -92,7 +73,6 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( const int * v, const int * u, const half2 & dm4, const half2 & ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll @@ -101,8 +81,8 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; // SIMD dot product of quantized values - sumi = __dp4a(vi0, u[2*i+0], sumi); - sumi = __dp4a(vi1, u[2*i+1], sumi); + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); } #ifdef GGML_CUDA_F16 @@ -118,9 +98,6 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q5_0_Q8_1_MMVQ 2 @@ -129,7 +106,6 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll @@ -139,23 +115,20 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values } const float2 ds8f = __half22float2(ds8); // second part effectively subtracts 16 from each quant value return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q5_1_Q8_1_MMVQ 2 @@ -164,7 +137,6 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll @@ -174,14 +146,14 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 - sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 - sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values } #ifdef GGML_CUDA_F16 @@ -197,10 +169,6 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it return sumi*d5d8 + m5s8 / (QI5_1 / vdr); - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q8_0_Q8_1_MMVQ 2 @@ -209,31 +177,26 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp template static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl( const int * v, const int * u, const T & d8_0, const T & d8_1) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll for (int i = 0; i < vdr; ++i) { // SIMD dot product of quantized values - sumi = __dp4a(v[i], u[i], sumi); + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); } return d8_0*d8_1 * ((T) sumi); -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( const int * v, const int * u, const half2 & dm8, const half2 & ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll for (int i = 0; i < vdr; ++i) { // SIMD dot product of quantized values - sumi = __dp4a(v[i], u[i], sumi); + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); } #ifdef GGML_CUDA_F16 @@ -249,20 +212,37 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it return sumi*d8d8 + m8s8 / (QI8_1 / vdr); -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl( + const int * v, const int * u, const float * d8_0, const float & d8_1) { + + float sumf = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) { + int sumi = 0; + +#pragma unroll + for (int i = i0; i < i0 + QI8_0/2; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + + sumf += d8_0[i0/(QI8_0/2)]*sumi; + } + + return d8_1*sumf; } #define VDR_Q2_K_Q8_1_MMVQ 1 -#define VDR_Q2_K_Q8_1_MMQ 2 +#define VDR_Q2_K_Q8_1_MMQ 4 // contiguous v/x values static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float * __restrict__ d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -272,58 +252,70 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int vi = (v >> (2*i)) & 0x03030303; - sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product // fill int with 4x m int m = sc >> 4; m |= m << 8; m |= m << 16; - sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values } const float2 dm2f = __half22float2(dm2); return dm2f.x*sumf_d - dm2f.y*sumf_m; -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -// contiguous u/y values +// contiguous v/x + u/y values +template static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( - const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, - const half2 & dm2, const float & d8) { + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - int sumi_d = 0; - int sumi_m = 0; + float sumf = 0.0f; + float sumf_d8 = 0.0f; #pragma unroll - for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) { - int sumi_d_sc = 0; + for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) { + const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]); + int sumi_d0 = 0; - const int sc = scales[i0 / (QI8_1/2)]; - - // fill int with 4x m - int m = sc >> 4; - m |= m << 8; - m |= m << 16; + const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]); + int sumi_d1 = 0; #pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product - sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m + sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0); } + sumf_d8 += dm2f0.x * sumi_d0; - sumi_d += sumi_d_sc * (sc & 0xF); +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1); + } + sumf_d8 += dm2f1.x * sumi_d1; + + if (i0/QI8_1 < ns8) { + const float2 s8f = __half22float2(s8[i0/QI8_1]); + sumf -= dm2f0.y*s8f.x; + sumf -= dm2f1.y*s8f.y; + } else { + int sumi_m0 = 0; +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0); + } + sumf_d8 -= dm2f0.y * sumi_m0; + + int sumi_m1 = 0; +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1); + } + sumf_d8 -= dm2f1.y * sumi_m1; + } } - const float2 dm2f = __half22float2(dm2); - - return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m); -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A + return sumf + d8*sumf_d8; } #define VDR_Q3_K_Q8_1_MMVQ 1 @@ -334,7 +326,6 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, const int & scale_offset, const float & d3, const float * __restrict__ d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf = 0.0f; #pragma unroll @@ -357,38 +348,32 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int vi = __vsubss4(vil, vih); - sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product } return d3 * sumf; -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d3, const float & d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; #pragma unroll for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { int sumi_sc = 0; +#pragma unroll for (int i = i0; i < i0 + QI8_1/2; ++i) { - sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product } sumi += sumi_sc * scales[i0 / (QI8_1/2)]; } return d3*d8 * sumi; -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q4_K_Q8_1_MMVQ 2 @@ -399,7 +384,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -408,8 +392,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; - const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u + const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u sumf_d += d8[i] * (dot1 * sc[i]); sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values @@ -418,18 +402,13 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const float2 dm4f = __half22float2(dm4); return dm4f.x*sumf_d - dm4f.y*sumf_m; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -439,7 +418,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( #pragma unroll for (int j = 0; j < QI8_1; ++j) { - sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product } const float2 ds8f = __half22float2(ds8[i]); @@ -451,10 +430,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const float2 dm4f = __half22float2(dm4); return dm4f.x*sumf_d - dm4f.y*sumf_m; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q5_K_Q8_1_MMVQ 2 @@ -465,7 +440,6 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -480,8 +454,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int v0i = vl0i | vh0i; const int v1i = vl1i | vh1i; - const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product - const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u + const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u sumf_d += d8[i] * (dot1 * sc[i]); sumf_m += d8[i] * (dot2 * m[i]); @@ -491,18 +465,13 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const float2 dm5f = __half22float2(dm5); return dm5f.x*sumf_d - dm5f.y*sumf_m; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -512,7 +481,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( #pragma unroll for (int j = 0; j < QI8_1; ++j) { - sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product } const float2 ds8f = __half22float2(ds8[i]); @@ -524,10 +493,6 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const float2 dm4f = __half22float2(dm4); return dm4f.x*sumf_d - dm4f.y*sumf_m; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } #define VDR_Q6_K_Q8_1_MMVQ 1 @@ -538,7 +503,6 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d, const float * __restrict__ d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf = 0.0f; #pragma unroll @@ -551,44 +515,39 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 - sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product } return d*sumf; -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } -// contiguous u/y values +// contiguous v/x + u/y values static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics float sumf_d = 0.0f; + const int sc_packed = get_int_b4(sc, 0); + const int8_t * sc_reg = (const int8_t *) &sc_packed; + #pragma unroll for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale #pragma unroll for (int i = i0; i < i0 + 2; ++i) { - sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product - sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product - sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product - sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product } - sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y); + sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y); } return d6 * sumf_d; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A } static __device__ __forceinline__ float vec_dot_q4_0_q8_1( @@ -601,9 +560,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( #pragma unroll for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + v[i] = get_int_b2(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0); } return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); @@ -620,9 +579,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( #pragma unroll for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8_aligned(bq4_1->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_1); + v[i] = get_int_b4(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1); } return vec_dot_q4_1_q8_1_impl(v, u, bq4_1->dm, bq8_1->ds); @@ -639,10 +598,10 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( #pragma unroll for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { - vl[i] = get_int_from_uint8(bq5_0->qs, iqs + i); - vh[i] = get_int_from_uint8(bq5_0->qh, 0) >> (4 * (iqs + i)); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_0); + vl[i] = get_int_b2(bq5_0->qs, iqs + i); + vh[i] = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0); } return vec_dot_q5_0_q8_1_impl(vl, vh, u, bq5_0->d, bq8_1->ds); @@ -659,10 +618,10 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( #pragma unroll for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { - vl[i] = get_int_from_uint8_aligned(bq5_1->qs, iqs + i); - vh[i] = get_int_from_uint8_aligned(bq5_1->qh, 0) >> (4 * (iqs + i)); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI5_1); + vl[i] = get_int_b4(bq5_1->qs, iqs + i); + vh[i] = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1); } return vec_dot_q5_1_q8_1_impl(vl, vh, u, bq5_1->dm, bq8_1->ds); @@ -678,8 +637,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( #pragma unroll for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_int8(bq8_0->qs, iqs + i); - u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + v[i] = get_int_b2(bq8_0->qs, iqs + i); + u[i] = get_int_b4(bq8_1->qs, iqs + i); } return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); @@ -695,13 +654,13 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( const uint8_t * scales = bq2_K->scales + scale_offset; - const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs); + const int v = get_int_b4(bq2_K->qs, iqs); int u[QR2_K]; float d8[QR2_K]; #pragma unroll for (int i = 0; i < QR2_K; ++ i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1); d8[i] = __low2float(bq8_1[bq8_offset + i].ds); } @@ -718,17 +677,17 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( const float d = bq3_K->d; - const int vl = get_int_from_uint8(bq3_K->qs, iqs); + const int vl = get_int_b2(bq3_K->qs, iqs); // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; int u[QR3_K]; float d8[QR3_K]; #pragma unroll for (int i = 0; i < QR3_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1); d8[i] = __low2float(bq8_1[bq8_offset + i].ds); } @@ -836,8 +795,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); - const int vl = get_int_from_uint8(bq6_K->ql, iqs); - const int vh = get_int_from_uint8(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + const int vl = get_int_b2(bq6_K->ql, iqs); + const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; const int8_t * scales = bq6_K->scales + scale_offset; @@ -846,335 +805,355 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( #pragma unroll for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + u[i] = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds); } return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); } +#define VDR_IQ2_XXS_Q8_1_MMVQ 2 +#define VDR_IQ2_XXS_Q8_1_MMQ 2 + static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx; -#if QR2_XXS == 8 - const int ib32 = iqs; - const uint16_t * q2 = bq2->qs + 4*ib32; - const uint8_t * aux8 = (const uint8_t *)q2; - const int8_t * q8 = bq8_1[ib32].qs; - uint32_t aux32 = q2[2] | (q2[3] << 16); + const int q2 = get_int_b2(bq2->qs, iqs); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1); + int sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); - const uint8_t signs = ksigns_iq2xs[aux32 & 127]; - for (int j = 0; j < 8; ++j) { - sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - aux32 >>= 7; +#pragma unroll + for (int k0 = 0; k0 < 8; k0 += 2) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); + sumi = ggml_cuda_dp4a(grid0, u0, sumi); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); + sumi = ggml_cuda_dp4a(grid1, u1, sumi); } - const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.25f; + + const int ls = aux32 >> 28; + sumi = (ls*sumi + sumi/2)/4; + const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; -#else - // iqs is 0...15 - const int ib32 = iqs/2; - const int il = iqs%2; - const uint16_t * q2 = bq2->qs + 4*ib32; - const uint8_t * aux8 = (const uint8_t *)q2; - const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]); - const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]); - const uint32_t aux32 = q2[2] | (q2[3] << 16); - const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * __low2float(bq8_1[ib32].ds) * 0.25f; - const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127]; - const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127]; - const int8_t * q8 = bq8_1[ib32].qs + 16*il; - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 8; ++j) { - sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1); - sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1); - } - return d * (sumi1 + sumi2); -#endif } +#define VDR_IQ2_XS_Q8_1_MMVQ 2 +#define VDR_IQ2_XS_Q8_1_MMQ 2 + static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx; - const int ib32 = iqs; - const uint16_t * q2 = bq2->qs + 4*ib32; - const int8_t * q8 = bq8_1[ib32].qs; - const uint8_t ls1 = bq2->scales[ib32] & 0xf; - const uint8_t ls2 = bq2->scales[ib32] >> 4; + const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + const int ls0 = bq2->scales[iqs/2] & 0x0F; + const int ls1 = bq2->scales[iqs/2] >> 4; + + int sumi0 = 0; int sumi1 = 0; - for (int l = 0; l < 2; ++l) { - const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); - const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]); - sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1); - sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1); - q8 += 8; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + if (l0 < 4) { + sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0); + sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0); + } else { + sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1); + sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1); + } } - int sumi2 = 0; - for (int l = 2; l < 4; ++l) { - const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); - const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]); - sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2); - sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2); - q8 += 8; - } - const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f; - return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); -#else - GGML_UNUSED(ksigns64); - NO_DEVICE_CODE; -#endif + const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4; + const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; } -// TODO +#define VDR_IQ2_S_Q8_1_MMVQ 2 +#define VDR_IQ2_S_Q8_1_MMQ 2 + static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics + const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx; - const int ib32 = iqs; - const int8_t * q8 = bq8_1[ib32].qs; - const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32; - const uint8_t ls1 = bq2->scales[ib32] & 0xf; - const uint8_t ls2 = bq2->scales[ib32] >> 4; + const int qs_packed = get_int_b2(bq2->qs, iqs/2); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq2->qh[iqs/2]; + + const int signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + + const int ls0 = bq2->scales[iqs/2] & 0x0F; + const int ls1 = bq2->scales[iqs/2] >> 4; + + int sumi0 = 0; int sumi1 = 0; - for (int l = 0; l < 2; ++l) { - const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); - const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); - const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - const int grid_l = __vsub4(grid[0] ^ signs0, signs0); - const int grid_h = __vsub4(grid[1] ^ signs1, signs1); - sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1); - sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1); - q8 += 8; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + if (l0 < 4) { + sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0); + sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0); + } else { + sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1); + sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1); + } } - int sumi2 = 0; - for (int l = 2; l < 4; ++l) { - const uint32_t * grid = (const uint32_t *)(iq2s_grid + (bq2->qs[4*ib32+l] | ((bq2->qh[ib32] << (8-2*l)) & 0x300))); - const uint32_t signs0 = __vcmpeq4(((signs[l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); - const uint32_t signs1 = __vcmpeq4(((signs[l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - const int grid_l = __vsub4(grid[0] ^ signs0, signs0); - const int grid_h = __vsub4(grid[1] ^ signs1, signs1); - sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2); - sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2); - q8 += 8; - } - const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f; - return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); -#else - GGML_UNUSED(ksigns64); - NO_DEVICE_CODE; -#endif + const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4; + + const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; } +#define VDR_IQ3_XXS_Q8_1_MMVQ 2 +#define VDR_IQ3_XXS_Q8_1_MMQ 2 + static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq + kbx; - const int ib32 = iqs; - const uint8_t * q3 = bq2->qs + 8*ib32; - const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32; - const int8_t * q8 = bq8_1[ib32].qs; - uint32_t aux32 = gas[0] | (gas[1] << 16); + const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx; + + const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2); + int sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0]; - const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1]; - const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); - const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]); - sumi = __dp4a(grid_l, *((int *)q8+0), sumi); - sumi = __dp4a(grid_h, *((int *)q8+1), sumi); - q8 += 8; - aux32 >>= 7; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + sumi = ggml_cuda_dp4a(grid_l, u0, sumi); + sumi = ggml_cuda_dp4a(grid_h, u1, sumi); } - const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f; + + const int ls = aux32 >> 28; + sumi = (ls*sumi + sumi/2)/2; + const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; -#else - NO_DEVICE_CODE; -#endif } +#define VDR_IQ3_S_Q8_1_MMVQ 2 +#define VDR_IQ3_S_Q8_1_MMQ 2 + // TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_iq3_s * bq2 = (const block_iq3_s *) vbq + kbx; - const int ib32 = iqs; - const uint8_t * qs = bq2->qs + 8*ib32; - const int8_t * q8 = bq8_1[ib32].qs; + const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx; + + const int2 qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq3->qh[iqs/2]; + + const int signs_packed_32 = get_int_b2(bq3->signs, iqs/2); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + int sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint32_t * grid1 = iq3s_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); - const uint32_t * grid2 = iq3s_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); - uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); - uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); - const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); - const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); - sumi = __dp4a(grid_l, *((int *)q8+0), sumi); - sumi = __dp4a(grid_h, *((int *)q8+1), sumi); - q8 += 8; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)], + iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + sumi = ggml_cuda_dp4a(grid_l, u0, sumi); + sumi = ggml_cuda_dp4a(grid_h, u1, sumi); } - const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds); + + sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F); + + const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; -#else - NO_DEVICE_CODE; -#endif } +#define VDR_IQ1_S_Q8_1_MMVQ 1 +#define VDR_IQ1_S_Q8_1_MMQ 1 + static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx; - const int ib32 = iqs; + const int qs_packed = get_int_b2(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq1->qh[iqs]; + int sumi = 0; -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const int * q8 = (const int *)bq8_1[ib32].qs; - for (int l = 0; l < 4; ++l) { - const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); - int grid0 = grid[0] & 0x0f0f0f0f; - int grid1 = (grid[0] >> 4) & 0x0f0f0f0f; - sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi)); +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi = ggml_cuda_dp4a(grid0, u0, sumi); + sumi = ggml_cuda_dp4a(grid1, u1, sumi); } -#else - const int8_t * q8 = bq8_1[ib32].qs; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); - for (int j = 0; j < 4; ++j) { - sumi += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4); - } - q8 += 8; - } -#endif - const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA; - const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1); - const float d = d1q * __low2float (bq8_1[ib32].ds); - const float m = d1q * __high2float(bq8_1[ib32].ds); - return d * sumi + m * delta; + + const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + const float2 ds = __half22float2(bq8_1[iqs].ds); + return d1q * (ds.x*sumi + ds.y*delta); } +#define VDR_IQ1_M_Q8_1_MMVQ 1 +#define VDR_IQ1_M_Q8_1_MMQ 1 + static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx; - const int ib32 = iqs; - int sumi[2] = {0, 0}; - float sumf[2] = {0.f, 0.f}; -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const int * q8 = (const int *)bq8_1[ib32].qs; - for (int l = 0; l < 4; ++l) { - const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 7) << 8))); - int grid0 = grid[0] & 0x0f0f0f0f; - int grid1 = (grid[0] >> 4) & 0x0f0f0f0f; - sumi[l/2] = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi[l/2])); - const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA; - const int sumy = __dp4a(q8[2*l+1], 0x01010101, __dp4a(q8[2*l+0], 0x01010101, 0)); - sumf[l/2] += delta*sumy; - } -#else - const int8_t * q8 = bq8_1[ib32].qs; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8))); + const int qs_packed = get_int_b4(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + int sumi[2] = {0}; + float sumf[2] = {0.0f}; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2)); + + const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]); + sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]); + + const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08); int sumy = 0; - for (int j = 0; j < 4; ++j) { - sumi[l/2] += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4); - sumy += q8[j] + q8[j+4]; - } - const float delta = (bq1->qh[2*ib32+l/2] >> 4*(l%2)) & 0x08 ? -1-IQ1M_DELTA : -1+IQ1M_DELTA; - sumf[l/2] += delta*sumy; - q8 += 8; + sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy); + sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy); + sumf[l0/4] += delta*sumy; } -#endif + + const uint16_t * sc = (const uint16_t *) bq1->scales; + iq1m_scale_t scale; - const uint16_t * sc = (const uint16_t *)bq1->scales; - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const float d = (float)scale.f16 * __low2float (bq8_1[ib32].ds); - return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1)); + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000); + const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds); + + const int tmp = sc[iqs/2] >> (6*(iqs%2)); + const int sc0 = 2*((tmp >> 0) & 0x07) + 1; + const int sc1 = 2*((tmp >> 3) & 0x07) + 1; + return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); } -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values, - int & val1, int & val2) { +static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) { + const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; + const int8_t * q0_8 = (const int8_t *) &q0_32; + const char4 val0_8 = make_char4( + kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]); - uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32; - aux32 = q4 & 0x0f0f0f0f; - uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8); - uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8); - val1 = v1 | (v2 << 16); - aux32 = (q4 >> 4) & 0x0f0f0f0f; - v1 = values[q8[0]] | (values[q8[1]] << 8); - v2 = values[q8[2]] | (values[q8[3]] << 8); - val2 = v1 | (v2 << 16); + const int q1_32 = (q4 >> 4) & 0x0F0F0F0F; + const int8_t * q1_8 = (const int8_t *) &q1_32; + const char4 val1_8 = make_char4( + kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]); + + return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); } -#endif + +#define VDR_IQ4_NL_Q8_1_MMVQ 2 +#define VDR_IQ4_NL_Q8_1_MMQ 4 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq4_nl * bq = (const block_iq4_nl *) vbq + kbx; + const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx; -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; - const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs; + const int * q8 = (const int *) bq8_1->qs + iqs; - const uint8_t * values = (const uint8_t *)kvalues_iq4nl; - - int v1, v2; - int sumi1 = 0, sumi2 = 0; + int sumi = 0; +#pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { - const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16); - get_int_from_table_16(aux, values, v1, v2); - sumi1 = __dp4a(v1, q8[l+0], sumi1); - sumi2 = __dp4a(v2, q8[l+4], sumi2); + const int aux_q4 = get_int_b2(bq4->qs, iqs + l); + const int2 v = get_int_from_table_16(aux_q4); + + sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi); + sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi); } -#else - const uint8_t * q4 = bq->qs + 4*iqs; - const int8_t * q8 = bq8_1->qs + 4*iqs; - - int sumi1 = 0, sumi2 = 0; - for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) { - sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf]; - sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >> 4]; - } -#endif - const float d = (float)bq->d * __low2float(bq8_1->ds); - return d * (sumi1 + sumi2); + const float d = __half2float(bq4->d) * __low2float(bq8_1->ds); + return d * sumi; } +#define VDR_IQ4_XS_Q8_1_MMVQ 4 +#define VDR_IQ4_XS_Q8_1_MMQ 4 + static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx; - const uint8_t * values = (const uint8_t *)kvalues_iq4nl; - // iqs is 0...7 - const int ib32 = iqs; - const int32_t * q8 = (const int *)bq8_1[ib32].qs; - const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32; - const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4); - const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds); - int v1, v2; - int sumi1 = 0, sumi2 = 0; + int sumi = 0; +#pragma unroll for (int j = 0; j < 4; ++j) { - get_int_from_table_16(q4[j], values, v1, v2); - sumi1 = __dp4a(v1, q8[j+0], sumi1); - sumi2 = __dp4a(v2, q8[j+4], sumi2); + const int aux_q4 = get_int_b4(bq4->qs, iqs + j); + const int2 v = get_int_from_table_16(aux_q4); + + const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0); + const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4); + + sumi = ggml_cuda_dp4a(v.x, u0, sumi); + sumi = ggml_cuda_dp4a(v.y, u1, sumi); } - return d * (sumi1 + sumi2); -#else - return vec_dot_iq4_xs_q8_1(vbq, bq8_1, kbx, iqs); -#endif + + const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4); + sumi *= ls - 32; + + const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds); + return d * sumi; } diff --git a/llama/ggml-impl.h b/llama/ggml-impl.h index f16d6674..80ca886d 100644 --- a/llama/ggml-impl.h +++ b/llama/ggml-impl.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -43,7 +43,7 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -#if defined(_WIN32) +#if defined(_MSC_VER) #define m512bh(p) p #define m512i(p) p @@ -635,6 +635,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #endif // defined(__ARM_NEON) && (!defined(__MSC_VER) +#ifdef __ARM_FEATURE_SVE +#include +#endif // __ARM_FEATURE_SVE + // precomputed f32 table for f16 (256 KB) // defined in ggml.c, initialized in ggml_init() extern float ggml_table_f32_f16[1 << 16]; @@ -656,21 +660,121 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) { #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x) #endif -#define GGML_HASHTABLE_FULL ((size_t)-1) -#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) +// bitset + +static_assert(sizeof(ggml_bitset_t) == 4, "bitset_t constants must be updated"); +#define BITSET_SHR 5 // log2(sizeof(ggml_bitset_t)*8) +#define BITSET_MASK (sizeof(ggml_bitset_t)*8 - 1) + +static size_t ggml_bitset_size(size_t n) { + return (n + BITSET_MASK) >> BITSET_SHR; +} + +static inline bool ggml_bitset_get(const ggml_bitset_t * bitset, size_t i) { + return !!(bitset[i >> BITSET_SHR] & (1u << (i & BITSET_MASK))); +} + +static inline void ggml_bitset_set(ggml_bitset_t * bitset, size_t i) { + bitset[i >> BITSET_SHR] |= (1u << (i & BITSET_MASK)); +} + +static inline void ggml_bitset_clear(ggml_bitset_t * bitset, size_t i) { + bitset[i >> BITSET_SHR] &= ~(1u << (i & BITSET_MASK)); +} + +// hash set + +#define GGML_HASHSET_FULL ((size_t)-1) +#define GGML_HASHSET_ALREADY_EXISTS ((size_t)-2) struct ggml_hash_set ggml_hash_set_new(size_t size); +void ggml_hash_set_free(struct ggml_hash_set * hash_set); -bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key); +// returns the minimum size for a hash set that can hold min_sz elements +size_t ggml_hash_size(size_t min_sz); -// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted -size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key); +// remove all elements from the hash set +void ggml_hash_set_reset(struct ggml_hash_set * hash_set); -// returns GGML_HASHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full -size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key); +// returns true if key is in the hash set +static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); + +// returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); + +// returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full +static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); // return index, asserts if table is full -size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key); +static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); + +// hash function for ggml_tensor +static inline size_t ggml_hash(const struct ggml_tensor * p) { + // the last 4 bits are always zero due to alignment + return (size_t)(uintptr_t)p >> 4; +} + +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set->size; + + // linear probing + size_t i = h; + while (ggml_bitset_get(hash_set->used, i) && hash_set->keys[i] != key) { + i = (i + 1) % hash_set->size; + if (i == h) { + // visited all hash table entries -> not found + return GGML_HASHSET_FULL; + } + } + return i; +} + +static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t i = ggml_hash_find(hash_set, key); + return i != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, i); +} + +static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set->size; + + // linear probing + size_t i = h; + do { + if (!ggml_bitset_get(hash_set->used, i)) { + ggml_bitset_set(hash_set->used, i); + hash_set->keys[i] = key; + return i; + } + if (hash_set->keys[i] == key) { + return GGML_HASHSET_ALREADY_EXISTS; + } + i = (i + 1) % hash_set->size; + } while (i != h); + + // visited all hash table entries -> not found + GGML_ABORT("fatal error"); +} + +static size_t ggml_hash_find_or_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key) { + size_t h = ggml_hash(key) % hash_set->size; + + // linear probing + size_t i = h; + do { + if (!ggml_bitset_get(hash_set->used, i)) { + ggml_bitset_set(hash_set->used, i); + hash_set->keys[i] = key; + return i; + } + if (hash_set->keys[i] == key) { + return i; + } + i = (i + 1) % hash_set->size; + } while (i != h); + + // visited all hash table entries -> not found + GGML_ABORT("fatal error"); +} #ifdef __cplusplus } diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m index a5c7d37b..67b638ac 100644 --- a/llama/ggml-metal-darwin_arm64.m +++ b/llama/ggml-metal-darwin_arm64.m @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -219,16 +219,16 @@ enum ggml_metal_kernel_type { //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 - GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, + GGML_METAL_KERNEL_TYPE_CPY_F32_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F16, + GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, - GGML_METAL_KERNEL_TYPE_CPY_F16_F16, - GGML_METAL_KERNEL_TYPE_CPY_F16_F32, GGML_METAL_KERNEL_TYPE_CONCAT, GGML_METAL_KERNEL_TYPE_SQR, GGML_METAL_KERNEL_TYPE_SUM_ROWS, @@ -677,14 +677,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); @@ -761,6 +761,12 @@ static id ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs } static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) { + for (size_t i = 0, n = 3; i < n; ++i) { + if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { + return false; + } + } + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -770,7 +776,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: - return true; + return ggml_is_contiguous(op->src[0]); default: return false; } @@ -830,8 +836,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const switch (op->src[0]->type) { case GGML_TYPE_F32: switch (op->type) { - case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_F16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: @@ -844,8 +850,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const } case GGML_TYPE_F16: switch (op->type) { - case GGML_TYPE_F16: case GGML_TYPE_F32: + case GGML_TYPE_F16: return true; default: return false; @@ -857,7 +863,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { - return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1; + return op->ne[3] == 1; } default: return false; @@ -889,7 +895,7 @@ static enum ggml_status ggml_metal_graph_compute( NSError * error = nil; if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); - GGML_ASSERT(!"capture failed"); + GGML_ABORT("capture failed"); } } @@ -951,7 +957,7 @@ static enum ggml_status ggml_metal_graph_compute( if (!ggml_metal_supports_op(ctx, dst)) { GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst)); - GGML_ASSERT(!"unsupported op"); + GGML_ABORT("unsupported op"); } if (should_capture) { @@ -1088,7 +1094,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break; case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); } bcast_row = true; @@ -1097,7 +1103,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break; case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break; case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); } } @@ -1151,7 +1157,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); } [encoder setComputePipelineState:pipeline]; @@ -1407,7 +1413,7 @@ static enum ggml_status ggml_metal_graph_compute( default: { GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_OP_SQR: @@ -1596,8 +1602,8 @@ static enum ggml_status ggml_metal_graph_compute( // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) switch (src0->type) { - case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; - case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; + case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; + case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; default: break; } @@ -1625,7 +1631,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); + default: GGML_ABORT("MUL MAT-MAT not implemented"); } [encoder setComputePipelineState:pipeline]; @@ -1798,14 +1804,10 @@ static enum ggml_status ggml_metal_graph_compute( default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); - GGML_ASSERT(false && "not implemented"); + GGML_ABORT("not implemented"); } }; - if (ggml_is_quantized(src0t)) { - GGML_ASSERT(ne00 >= nth0*nth1); - } - [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1884,9 +1886,10 @@ static enum ggml_status ggml_metal_graph_compute( // ne21 = n_rows const int dst_rows = ne20*ne21; const int dst_rows_min = n_as; + const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4; // max size of the rowids array in the kernel shared buffer - GGML_ASSERT(dst_rows <= 2048); + GGML_ASSERT(dst_rows <= dst_rows_max); // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -1930,7 +1933,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); + default: GGML_ABORT("MUL_MAT_ID not implemented"); } [encoder setComputePipelineState:pipeline]; @@ -2097,7 +2100,7 @@ static enum ggml_status ggml_metal_graph_compute( default: { GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); - GGML_ASSERT(false && "not implemented"); + GGML_ABORT("not implemented"); } }; @@ -2197,7 +2200,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break; case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); + default: GGML_ABORT("not implemented"); } [encoder setComputePipelineState:pipeline]; @@ -2335,13 +2338,13 @@ static enum ggml_status ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); }; } else { switch (src0->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); }; } @@ -2418,7 +2421,7 @@ static enum ggml_status ggml_metal_graph_compute( switch (dst->type) { case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break; case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); }; [encoder setComputePipelineState:pipeline]; @@ -2575,7 +2578,7 @@ static enum ggml_status ggml_metal_graph_compute( switch (order) { case GGML_SORT_ORDER_ASC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline; break; case GGML_SORT_ORDER_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break; - default: GGML_ASSERT(false); + default: GGML_ABORT("fatal error"); }; [encoder setComputePipelineState:pipeline]; @@ -2664,7 +2667,7 @@ static enum ggml_status ggml_metal_graph_compute( { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ASSERT(false && "add template specialization for this size"); + GGML_ABORT("add template specialization for this size"); } } } else { @@ -2677,7 +2680,7 @@ static enum ggml_status ggml_metal_graph_compute( { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); GGML_METAL_LOG_ERROR("add template specialization for this size\n"); - GGML_ASSERT(false && "add template specialization for this size"); + GGML_ABORT("add template specialization for this size"); } } } @@ -2790,26 +2793,26 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0); switch (dstt) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break; case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); + default: GGML_ABORT("not implemented"); }; } break; case GGML_TYPE_F16: { switch (dstt) { - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; - default: GGML_ASSERT(false && "not implemented"); + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break; + default: GGML_ABORT("not implemented"); }; } break; - default: GGML_ASSERT(false && "not implemented"); + default: GGML_ABORT("not implemented"); } [encoder setComputePipelineState:pipeline]; @@ -2837,7 +2840,7 @@ static enum ggml_status ggml_metal_graph_compute( default: { GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -3066,12 +3069,6 @@ GGML_CALL static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend UNUSED(buft); } -GGML_CALL static bool ggml_backend_metal_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - return ggml_backend_is_metal(backend) || ggml_backend_is_cpu(backend); - - UNUSED(buft); -} - GGML_CALL static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return true; @@ -3086,7 +3083,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) { /* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment, /* .get_max_size = */ ggml_backend_metal_buffer_type_get_max_size, /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend, /* .is_host = */ ggml_backend_metal_buffer_type_is_host, }, /* .context = */ NULL, @@ -3201,6 +3197,12 @@ GGML_CALL static bool ggml_backend_metal_supports_op(ggml_backend_t backend, con return ggml_metal_supports_op(metal_ctx, op); } +GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name; + + UNUSED(backend); +} + static struct ggml_backend_i ggml_backend_metal_i = { /* .get_name = */ ggml_backend_metal_name, /* .free = */ ggml_backend_metal_free, @@ -3211,9 +3213,11 @@ static struct ggml_backend_i ggml_backend_metal_i = { /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, /* .supports_op = */ ggml_backend_metal_supports_op, + /* .supports_buft = */ ggml_backend_metal_supports_buft, /* .offload_op = */ NULL, /* .event_new = */ NULL, /* .event_free = */ NULL, diff --git a/llama/ggml-metal.h b/llama/ggml-metal.h index a44dd32a..be606ecd 100644 --- a/llama/ggml-metal.h +++ b/llama/ggml-metal.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -89,4 +89,3 @@ GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); #ifdef __cplusplus } #endif - diff --git a/llama/ggml-metal.metal b/llama/ggml-metal.metal index e081496c..287ff1ce 100644 --- a/llama/ggml-metal.metal +++ b/llama/ggml-metal.metal @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -1245,9 +1245,10 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); } -#define N_F32_F32 4 +#define N_MV_T_T 4 -void kernel_mul_mv_f32_f32_impl( +template +void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, @@ -1265,13 +1266,12 @@ void kernel_mul_mv_f32_f32_impl( uint64_t nb12, int64_t ne0, int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - + uint r2, + uint r3, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F32_F32; + const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; const uint i12 = im%ne12; @@ -1279,20 +1279,20 @@ void kernel_mul_mv_f32_f32_impl( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const float * x = (device const float *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { + for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; + sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); @@ -1301,32 +1301,32 @@ void kernel_mul_mv_f32_f32_impl( } } } else { - device const float4 * x4 = (device const float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { + device const T04 * x4 = (device const T04 *) x; + for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; if (r1 >= ne11) { break; } - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; + device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12); + device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } } } -[[host_name("kernel_mul_mv_f32_f32")]] -kernel void kernel_mul_mv_f32_f32( +template +kernel void kernel_mul_mv( device const char * src0, device const char * src1, device float * dst, @@ -1348,90 +1348,38 @@ kernel void kernel_mul_mv_f32_f32( constant uint & r3, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); + kernel_mul_mv_impl( + src0, + src1, + dst, + ne00, + ne01, + ne02, + nb00, + nb01, + nb02, + ne10, + ne11, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + r2, + r3, + tgpig, + tiisg); } -#define N_F16_F16 4 +typedef decltype(kernel_mul_mv) mul_mv_t; -kernel void kernel_mul_mv_f16_f16( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F16; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12); - device const half4 * y4 = (device const half4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -void kernel_mul_mv_f16_f32_1row_impl( +template +kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, @@ -1463,7 +1411,7 @@ void kernel_mul_mv_f16_f32_1row_impl( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half * x = (device const half *) (src0 + offset0); + device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); float sumf = 0; @@ -1476,153 +1424,29 @@ void kernel_mul_mv_f16_f32_1row_impl( dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } else { - device const half4 * x4 = (device const half4 *) x; + device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; + for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } + float all_sum = simd_sum(sumf); + if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; } } } -[[host_name("kernel_mul_mv_f16_f32_1row")]] -kernel void kernel_mul_mv_f16_f32_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} +typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; -#define N_F16_F32 4 - -void kernel_mul_mv_f16_f32_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_F16_F32; - const int64_t im = tgpig.z; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - - device const half * x = (device const half *) (src0 + offset0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - - float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - device const half4 * x4 = (device const half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - device const float4 * y4 = (device const float4 *) y; - - float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; - } - - float all_sum = simd_sum(sumf); - if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -[[host_name("kernel_mul_mv_f16_f32")]] -kernel void kernel_mul_mv_f16_f32( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); -} +template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; // Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mv_f16_f32_l4( +template +kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, @@ -1654,14 +1478,14 @@ kernel void kernel_mul_mv_f16_f32_l4( const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; - device const half4 * x4 = (device const half4 *) (src0 + offset0); + device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); float sumf = 0; for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); @@ -1671,6 +1495,10 @@ kernel void kernel_mul_mv_f16_f32_l4( } } +typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; + +template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; + static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); return 1.0f - min(1.0f, max(0.0f, y)); @@ -2791,9 +2619,10 @@ kernel void kernel_flash_attn_ext_vec_f16( template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; -kernel void kernel_cpy_f16_f16( - device const half * src0, - device half * dst, +template +kernel void kernel_cpy( + device const void * src0, + device void * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -2824,138 +2653,20 @@ kernel void kernel_cpy_f16_f16( const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; + device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = (T1) src[0]; } } -kernel void kernel_cpy_f16_f32( - device const half * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; +typedef decltype(kernel_cpy) kernel_cpy_t; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - device const float * src0, - device half * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - device const float * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; - - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; kernel void kernel_cpy_f32_q8_0( device const float * src0, @@ -5072,7 +4783,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5104,7 +4815,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - for (int row = 0; row < 2; ++row) { + for (int row = 0; row < 2 && first_row + row < ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; @@ -5756,9 +5467,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 } template -kernel void kernel_get_rows( +kernel void kernel_get_rows_q( device const void * src0, - device const char * src1, + device const void * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, @@ -5771,27 +5482,24 @@ kernel void kernel_get_rows( uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { - //const int64_t i = tgpig; - //const int64_t r = ((device int32_t *) src1)[i]; - const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t i02 = i11; for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) { float4x4 temp; - dequantize_func( - ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); + dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp); *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; } } -kernel void kernel_get_rows_f32( +template +kernel void kernel_get_rows_f( device const void * src0, - device const char * src1, + device const void * src1, device float * dst, constant int64_t & ne00, constant uint64_t & nb01, @@ -5807,47 +5515,19 @@ kernel void kernel_get_rows_f32( const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t i02 = i11; for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; - - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + (( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; } } kernel void kernel_get_rows_i32( device const void * src0, - device const char * src1, + device const void * src1, device int32_t * dst, constant int64_t & ne00, constant uint64_t & nb01, @@ -5863,13 +5543,13 @@ kernel void kernel_get_rows_i32( const int64_t i10 = tgpig.x; const int64_t i11 = tgpig.y; - const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0]; + const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0]; const int64_t i02 = i11; for (int ind = tiitg; ind < ne00; ind += tptg.x) { - ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] = - ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind]; + (( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] = + ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind]; } } @@ -5886,28 +5566,28 @@ kernel void kernel_get_rows_i32( #define SG_MAT_ROW 8 // each block_q contains 16*nl weights -template -void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +template +kernel void kernel_mul_mm(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup T * sa = (threadgroup T *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); const uint r0 = tgpig.y; @@ -5922,7 +5602,7 @@ void kernel_mul_mm_impl(device const uchar * src0, short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; + simdgroup_T8x8 ma[4]; simdgroup_float8x8 mb[2]; simdgroup_float8x8 c_res[8]; for (int i = 0; i < 8; i++){ @@ -5945,7 +5625,7 @@ void kernel_mul_mm_impl(device const uchar * src0, for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - half4x4 temp_a; + T4x4 temp_a; dequantize_func(x, il, temp_a); threadgroup_barrier(mem_flags::mem_threadgroup); @@ -5965,7 +5645,7 @@ void kernel_mul_mm_impl(device const uchar * src0, threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); #pragma unroll(4) @@ -6141,48 +5821,6 @@ void kernel_mul_mm_id_impl( } } -template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mm_impl( - src0, - src1, - dst, - ne00, - ne02, - nb01, - nb02, - ne12, - nb10, - nb11, - nb12, - ne0, - ne1, - r2, - r3, - shared_memory, - tgpig, - tiitg, - sgitg); -} - template kernel void kernel_mul_mm_id( device const uchar * src0s, @@ -6263,69 +5901,60 @@ kernel void kernel_mul_mm_id( // get rows // -typedef void (get_rows_t)( - device const void * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb1, - constant uint64_t & nb2, - uint3, uint, uint3); +typedef decltype(kernel_get_rows_f) get_rows_f_t; -//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; -//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; + +typedef decltype(kernel_get_rows_q) get_rows_q_t; + +template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; // // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mat_mm_t; +typedef decltype(kernel_mul_mm) mat_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication @@ -6462,7 +6091,7 @@ void mmv_fn( impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); } -typedef decltype(mmv_fn) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( @@ -6540,20 +6169,20 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; +template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; @@ -6563,4 +6192,3 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; - diff --git a/llama/ggml-quants.c b/llama/ggml-quants.c index cc7d6f16..81d64d19 100644 --- a/llama/ggml-quants.c +++ b/llama/ggml-quants.c @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -30,8 +30,6 @@ #include "ggml-quants.h" #include "ggml-impl.h" -#define GGML_COMMON_IMPL_C -#include "ggml-common.h" #include #include @@ -686,7 +684,7 @@ static inline __m128i packNibbles( __m256i bytes ) { #endif //__loongarch_asx // reference implementation for deterministic creation of model files -void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { +void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -724,11 +722,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict } void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_reference(x, y, k); + quantize_row_q4_0_ref(x, y, k); } -void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) { +void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; assert(k % qk == 0); @@ -766,10 +764,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict } void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_1_reference(x, y, k); + quantize_row_q4_1_ref(x, y, k); } -void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) { +void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) { static const int qk = QK5_0; assert(k % qk == 0); @@ -814,10 +812,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict } void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_0_reference(x, y, k); + quantize_row_q5_0_ref(x, y, k); } -void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) { +void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) { const int qk = QK5_1; assert(k % qk == 0); @@ -862,11 +860,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict } void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q5_1_reference(x, y, k); + quantize_row_q5_1_ref(x, y, k); } // reference implementation for deterministic creation of model files -void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) { +void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1104,6 +1102,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) } vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } #elif defined(__loongarch_asx) for (int i = 0; i < nb; i++) { @@ -1171,12 +1170,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #else GGML_UNUSED(nb); // scalar - quantize_row_q8_0_reference(x, y, k); + quantize_row_q8_0_ref(x, y, k); #endif } // reference implementation for deterministic creation of model files -void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) { +void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); assert(k % QK8_1 == 0); const int nb = k / QK8_1; @@ -1463,6 +1462,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) accv = vec_add(accv, vec_sld(accv, accv, 4)); accv = vec_add(accv, vec_sld(accv, accv, 8)); y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + } #elif defined(__loongarch_asx) for (int i = 0; i < nb; i++) { @@ -1534,7 +1534,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #else GGML_UNUSED(nb); // scalar - quantize_row_q8_1_reference(x, y, k); + quantize_row_q8_1_ref(x, y, k); #endif } @@ -1925,7 +1925,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * //========================- 2-bit (de)-quantization -void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) { +void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2028,7 +2028,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q2_K_reference(x, vy, k); + quantize_row_q2_K_ref(x, vy, k); } static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights, @@ -2252,7 +2252,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row); if (!quant_weights) { - quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2267,7 +2267,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr //========================= 3-bit (de)-quantization -void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) { +void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2394,7 +2394,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { - quantize_row_q3_K_reference(x, vy, k); + quantize_row_q3_K_ref(x, vy, k); } static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { @@ -2484,7 +2484,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row); if (!quant_weights) { - quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2499,7 +2499,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 4-bit (de)-quantization -void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) { +void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -2598,7 +2598,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q4_K * restrict y = vy; - quantize_row_q4_K_reference(x, y, k); + quantize_row_q4_K_ref(x, y, k); } static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2677,7 +2677,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row); if (!quant_weights) { - quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2692,7 +2692,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 5-bit (de)-quantization -void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) { +void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -2809,7 +2809,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q5_K * restrict y = vy; - quantize_row_q5_K_reference(x, y, k); + quantize_row_q5_K_ref(x, y, k); } static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -2908,7 +2908,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row); if (!quant_weights) { - quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -2923,7 +2923,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr // ====================== 6-bit (de)-quantization -void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) { +void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3027,7 +3027,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_q6_K * restrict y = vy; - quantize_row_q6_K_reference(x, y, k); + quantize_row_q6_K_ref(x, y, k); } static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { @@ -3117,7 +3117,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row); if (!quant_weights) { - quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row); } else { char * qrow = (char *)dst; @@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { - quantize_row_q4_0_reference(x, y, n_per_row); + quantize_row_q4_0_ref(x, y, n_per_row); return; } @@ -3160,7 +3160,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row); @@ -3177,7 +3177,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri static_assert(QK4_1 == 32, "QK4_1 must be 32"); if (!quant_weights) { - quantize_row_q4_1_reference(x, y, n_per_row); + quantize_row_q4_1_ref(x, y, n_per_row); return; } @@ -3205,7 +3205,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row); @@ -3222,7 +3222,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri static_assert(QK5_0 == 32, "QK5_0 must be 32"); if (!quant_weights) { - quantize_row_q5_0_reference(x, y, n_per_row); + quantize_row_q5_0_ref(x, y, n_per_row); return; } @@ -3259,7 +3259,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row); @@ -3276,7 +3276,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri static_assert(QK5_1 == 32, "QK5_1 must be 32"); if (!quant_weights) { - quantize_row_q5_1_reference(x, y, n_per_row); + quantize_row_q5_1_ref(x, y, n_per_row); return; } @@ -3312,7 +3312,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row); @@ -3328,7 +3328,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row); - quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } @@ -3616,7 +3616,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, //===================================== Q8_K ============================================== -void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) { +void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) { assert(k % QK_K == 0); const int64_t nb = k / QK_K; @@ -3667,7 +3667,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6 } void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q8_K_reference(x, y, k); + quantize_row_q8_K_ref(x, y, k); } //===================================== Dot ptoducts ================================= @@ -3834,59 +3834,61 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, vget_low_f32(sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); return; } #endif + + int ib = 0; + float sumf = 0; + #if defined(__ARM_FEATURE_SVE) - const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); - const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + if (svcntb() == QK8_0) { + const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); + const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); } - - *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * restrict x0 = &x[ib + 0]; + const block_q4_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); const int8x16_t s8b = vdupq_n_s8(0x8); @@ -3920,23 +3922,23 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. const __m256i off = _mm256_set1_epi8( 8 ); qx = _mm256_sub_epi8( qx, off ); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -3944,28 +3946,28 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = _mm256_fmadd_ps( d, q, acc ); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); const __m128i lowMask = _mm_set1_epi8(0xF); const __m128i off = _mm_set1_epi8(8); - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs); __m128i bx_0 = _mm_and_si128(lowMask, tmp); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); bx_0 = _mm_sub_epi8(bx_0, off); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); bx_0 = _mm_sub_epi8(bx_0, off); const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); @@ -3976,7 +3978,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__SSSE3__) // set constants const __m128i lowMask = _mm_set1_epi8(0xF); @@ -3988,94 +3990,40 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r __m128 acc_2 = _mm_setzero_ps(); __m128 acc_3 = _mm_setzero_ps(); - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + for (; ib + 1 < nb; ib += 2) { + _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); bx_0 = _mm_sub_epi8(bx_0, off); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); bx_1 = _mm_sub_epi8(bx_1, off); const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); bx_2 = _mm_sub_epi8(bx_2, off); const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( d_2_3, p3 ); - } - - assert(nb % 2 == 0); // TODO: handle odd nb - - // Main loop - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); bx_3 = _mm_sub_epi8(bx_3, off); const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); @@ -4098,18 +4046,16 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc_3 = _mm_add_ps(p3_d, acc_3); } - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk/2); - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); @@ -4132,30 +4078,29 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); const vector signed char v8 = vec_splats((signed char)0x8); vector float vsumf0 = vec_splats(0.0f); -#pragma GCC unroll 4 - for (int i = 0; i < nb; i++) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl(16, y[i].qs); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); vector signed char q4x0 = vec_and(qxs, lowMask); vector signed char q4x1 = vec_sr(qxs, v4); @@ -4166,9 +4111,10 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - qv0 = vec_add(qv0, qv1); + vector signed int vsumi0 = v0; - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); } @@ -4176,24 +4122,24 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros __m256 acc = (__m256)__lasx_xvldi(0); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. const __m256i off = __lasx_xvreplgr2vr_b( 8 ); qx = __lasx_xvsub_b( qx, off ); - __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -4201,7 +4147,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = __lasx_xvfmadd_s( d, q, acc ); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__loongarch_sx) // set constants const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); @@ -4213,89 +4159,38 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r __m128 acc_2 = __lsx_vldi(0); __m128 acc_3 = __lsx_vldi(0); - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + for (; ib + 1 < nb; ib += 2) { // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); + const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[0].qs, 0); + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[0].qs, 0); + __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); bx_0 = __lsx_vsub_b(bx_0, off); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[0].qs + 16), 0); + __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); bx_1 = __lsx_vsub_b(bx_1, off); const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[1].qs, 0); + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[1].qs, 0); + __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); bx_2 = __lsx_vsub_b(bx_2, off); const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[1].qs + 16), 0); - bx_3 = __lsx_vsub_b(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = __lsx_vffint_s_w(i32_0); - __m128 p1 = __lsx_vffint_s_w(i32_1); - __m128 p2 = __lsx_vffint_s_w(i32_2); - __m128 p3 = __lsx_vffint_s_w(i32_3); - - // Apply the scale - acc_0 = __lsx_vfmul_s( d_0_1, p0 ); - acc_1 = __lsx_vfmul_s( d_0_1, p1 ); - acc_2 = __lsx_vfmul_s( d_2_3, p2 ); - acc_3 = __lsx_vfmul_s( d_2_3, p3 ); - } - - assert(nb % 2 == 0); // TODO: handle odd nb - - // Main loop - for (int i = 2; i < nb; i+=2) { - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[i].qs, 0); - - __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[i].qs, 0); - bx_0 = __lsx_vsub_b(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[i].qs + 16), 0); - bx_1 = __lsx_vsub_b(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - //_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - //_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); - - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[i + 1].qs, 0); - - __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[i + 1].qs, 0); - bx_2 = __lsx_vsub_b(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[i + 1].qs + 16), 0); + __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); bx_3 = __lsx_vsub_b(bx_3, off); const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); @@ -4318,27 +4213,25 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc_3 = __lsx_vfadd_s(p3_d, acc_3); } - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; - const int v1 = (x[i].qs[j] >> 4) - 8; + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); } - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); } *s = sumf; -#endif } void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -4424,11 +4317,15 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); sumv2 = vaddq_f32(sumv2, summs0); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); return; } #endif + + int ib = 0; + float sumf = 0; + // TODO: add WASM SIMD #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); @@ -4436,13 +4333,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r float summs = 0; - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q4_1 * restrict x0 = &x[ib + 0]; + const block_q4_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib + 0]; + const block_q8_1 * restrict y1 = &y[ib + 1]; summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); @@ -4471,7 +4366,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; #elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -4479,11 +4374,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r float summs = 0; // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - const float d1 = GGML_FP16_TO_FP32(y[i].d); + for (; ib < nb; ++ib) { + const float d0 = GGML_FP16_TO_FP32(x[ib].d); + const float d1 = GGML_FP16_TO_FP32(y[ib].d); - summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); const __m256 d0v = _mm256_set1_ps( d0 ); const __m256 d1v = _mm256_set1_ps( d1 ); @@ -4492,8 +4387,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[i].qs); - const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); const __m256 xy = mul_sum_us8_pairs_float(qx, qy); @@ -4505,18 +4400,16 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r #endif } - *s = hsum_float_8(acc) + summs; + sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk/2); - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { // load elements - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); // mask and store lower part of x, and then upper part vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); @@ -4535,43 +4428,40 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); #pragma GCC unroll 4 - for (int i = 0; i < nb; i++) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m)); - vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.0f, 0.0f, 0.0f}; + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; vsumf0 = vec_madd(vxmin, vys, vsumf0); - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl(16, y[i].qs); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); + vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + vector signed int vsumi0 = v0; - qv0 = vec_add(qv0, qv1); - - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vsumi0 = vec_msum(q8y0, q4x0, vsumi0); + vsumi0 = vec_msum(q8y1, q4x1, vsumi0); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); } @@ -4579,7 +4469,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros @@ -4588,11 +4478,11 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r float summs = 0; // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - const float d1 = GGML_FP16_TO_FP32(y[i].d); + for (; ib < nb; ++ib) { + const float d0 = GGML_FP16_TO_FP32(x[ib].d); + const float d1 = GGML_FP16_TO_FP32(y[ib].d); - summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); @@ -4601,8 +4491,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[i].qs); - const __m256i qy = __lasx_xvld( (const __m256i *)y[i].qs, 0); + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); const __m256 xy = mul_sum_us8_pairs_float(qx, qy); @@ -4610,33 +4500,34 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r acc = __lasx_xvfmadd_s( d0d1, xy, acc ); } - *s = hsum_float_8(acc) + summs; - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } *s = sumf; -#endif } void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; + int ib = 0; + float sumf = 0; + assert(n % qk == 0); assert(qk == QK5_0); assert(nrc == 1); @@ -4658,13 +4549,11 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r uint64_t tmp0[4]; uint64_t tmp1[4]; - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q5_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib]; + const block_q8_0 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -4716,7 +4605,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); @@ -4724,9 +4613,9 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r uint64_t tmp[4]; // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; + for (; ib < nb; ++ib) { + const block_q5_0 * restrict x0 = &x[ib]; + const block_q8_0 * restrict y0 = &y[ib]; const v128_t m4b = wasm_i8x16_splat(0x0F); @@ -4776,23 +4665,23 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); } - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); qx = _mm256_or_si256(qx, bxhi); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -4800,19 +4689,19 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = _mm256_fmadd_ps(d, q, acc); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); __m128i mask = _mm_set1_epi8((char)0xF0); // Main loop - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); bxhil = _mm_andnot_si128(bxhil, mask); @@ -4823,7 +4712,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r bxh = _mm_or_si128(bxh, bxhih); bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); @@ -4831,10 +4720,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - uint32_t qh; size_t vl = __riscv_vsetvl_e8m1(qk/2); @@ -4846,8 +4733,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); @@ -4866,10 +4753,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); @@ -4893,11 +4780,9 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)4); @@ -4905,27 +4790,27 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r vector float vsumf0 = vec_splats(0.0f); #pragma GCC unroll 4 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[i].qh[0]]), (uint64_t)(table_b2b_1[x[i].qh[1]])}; - vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[i].qh[2]]), (uint64_t)(table_b2b_1[x[i].qh[3]])}; + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; vector signed char qh0 = (vector signed char)aux64x2_0; vector signed char qh1 = (vector signed char)aux64x2_1; - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl( 16, y[i].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); @@ -4940,23 +4825,23 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros __m256 acc = (__m256)__lasx_xvldi(0); // Main loop - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); //FIXME + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); qx = __lasx_xvor_v(qx, bxhi); - __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -4964,39 +4849,40 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = __lasx_xvfmadd_s(d, q, acc); } - *s = hsum_float_8(acc); - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + memcpy(&qh, x[ib].qh, sizeof(qh)); - int sumi = 0; + int sumi0 = 0; + int sumi1 = 0; for (int j = 0; j < qk/2; ++j) { const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; } *s = sumf; -#endif } void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_1; const int nb = n / qk; + int ib = 0; + float sumf = 0; + assert(n % qk == 0); assert(qk == QK5_1); assert(nrc == 1); @@ -5021,13 +4907,11 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r uint64_t tmp0[4]; uint64_t tmp1[4]; - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q5_1 * restrict x1 = &x[ib + 1]; + const block_q8_1 * restrict y0 = &y[ib]; + const block_q8_1 * restrict y1 = &y[ib + 1]; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -5082,7 +4966,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; #elif defined(__wasm_simd128__) v128_t sumv = wasm_f32x4_splat(0.0f); @@ -5092,9 +4976,9 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r uint64_t tmp[4]; // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; + for (; ib < nb; ++ib) { + const block_q5_1 * restrict x0 = &x[ib]; + const block_q8_1 * restrict y0 = &y[ib]; summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); @@ -5146,8 +5030,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); } - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -5155,25 +5039,25 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r float summs = 0.0f; // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); - summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); qx = _mm256_or_si256(qx, bxhi); - const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d)); - const __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_us8_pairs_float(qx, qy); acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); } - *s = hsum_float_8(acc) + summs; + sumf = hsum_float_8(acc) + summs; #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -5182,13 +5066,13 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r float summs = 0.0f; // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); - summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - __m256i bx_0 = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); __m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); bxhil = _mm_and_si128(bxhil, mask); @@ -5199,18 +5083,16 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r bxh = _mm_or_si128(bxh, bxhih); bx_0 = MM256_SET_M128I(bxh, bxl); - const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[i].d)); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); } - *s = hsum_float_8(acc) + summs; + sumf = hsum_float_8(acc) + summs; #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - uint32_t qh; size_t vl = __riscv_vsetvl_e8m1(qk/2); @@ -5219,8 +5101,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); + for (; ib < nb; ++ib) { + memcpy(&qh, x[ib].qh, sizeof(uint32_t)); // load qh vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); @@ -5242,10 +5124,10 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); // load - vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[ib].qs, vl); - vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); - vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[ib].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[ib].qs+16, vl); vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); @@ -5266,50 +5148,47 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } - *s = sumf; - #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); #pragma GCC unroll 4 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].m)); - vector float vys = {GGML_FP16_TO_FP32(y[i].s), 0.f, 0.f, 0.f}; + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; vsumf0 = vec_madd(vxmin, vys, vsumf0); - vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[i].qh[0]]), (uint64_t)(table_b2b_0[x[i].qh[1]])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[i].qh[2]]), (uint64_t)(table_b2b_0[x[i].qh[3]])}; + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; vector signed char qh0 = (vector signed char)aux64x2_0; vector signed char qh1 = (vector signed char)aux64x2_1; - vector signed char qxs = (vector signed char)vec_xl( 0, x[i].qs); + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q5x0 = vec_or(vec_and(qxs, lowMask), qh0); - vector signed char q5x1 = vec_or(vec_sr(qxs, v4), qh1); + vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); + vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl( 16, y[i].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); - vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + vector signed int vsumi0 = v0; - qv0 = vec_add(qv0, qv1); - - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + vsumi0 = vec_msum(q8y0, q5x0, vsumi0); + vsumi0 = vec_msum(q8y1, q5x1, vsumi0); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); } @@ -5317,7 +5196,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros @@ -5326,51 +5205,49 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r float summs = 0.0f; // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d)); + for (; ib < nb; ++ib) { + const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d)); - summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - __m256i qx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); qx = __lasx_xvor_v(qx, bxhi); - const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[i].d)); - const __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_us8_pairs_float(qx, qy); acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); } - *s = hsum_float_8(acc) + summs; - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc) + summs; +#endif + for (; ib < nb; ++ib) { uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + memcpy(&qh, x[ib].qh, sizeof(qh)); - int sumi = 0; + int sumi0 = 0; + int sumi1 = 0; for (int j = 0; j < qk/2; ++j) { const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d))*sumi + GGML_FP16_TO_FP32(x[i].m)*GGML_FP16_TO_FP32(y[i].s); + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); } *s = sumf; -#endif } void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -5447,42 +5324,44 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r return; } #endif + + int ib = 0; + float sumf = 0; + #if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); + if (svcntb() == QK8_0) { + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); } - - *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); #elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - assert(nb % 2 == 0); // TODO: handle odd nb - - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * restrict x0 = &x[ib + 0]; + const block_q8_0 * restrict x1 = &x[ib + 1]; + const block_q8_0 * restrict y0 = &y[ib + 0]; + const block_q8_0 * restrict y1 = &y[ib + 1]; const int8x16_t x0_0 = vld1q_s8(x0->qs); const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); @@ -5504,17 +5383,17 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) || defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -5526,15 +5405,14 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r #endif } - *s = hsum_float_8(acc); + sumf = hsum_float_8(acc); #elif defined(__riscv_v_intrinsic) - float sumf = 0.0; size_t vl = __riscv_vsetvl_e8m1(qk); - for (int i = 0; i < nb; i++) { + for (; ib < nb; ++ib) { // load elements - vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl); + vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[ib].qs, vl); + vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl); @@ -5543,40 +5421,38 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); } - - *s = sumf; - #elif defined(__POWER9_VECTOR__) + const vector signed int v0 = vec_splats((int32_t)0); vector float vsumf0 = vec_splats(0.0f); -#pragma GCC unroll 4 - for (int i = 0; i < nb; i++) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[i].d)); + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); vector float vd = vec_mul(vxd, vyd); - vector signed char q8x0 = vec_xl( 0, x[i].qs); - vector signed char q8x1 = vec_xl(16, x[i].qs); - vector signed char q8y0 = vec_xl( 0, y[i].qs); - vector signed char q8y1 = vec_xl(16, y[i].qs); + vector signed char q8x0 = vec_xl( 0, x[ib].qs); + vector signed char q8x1 = vec_xl(16, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); vector signed short qv0 = vec_mule(q8x0, q8y0); vector signed short qv1 = vec_mulo(q8x0, q8y0); vector signed short qv2 = vec_mule(q8x1, q8y1); vector signed short qv3 = vec_mulo(q8x1, q8y1); - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackh(qv1)); - vector signed int vsumi1 = vec_add(vec_unpackl(qv0), vec_unpackl(qv1)); - vector signed int vsumi2 = vec_add(vec_unpackh(qv2), vec_unpackh(qv3)); - vector signed int vsumi3 = vec_add(vec_unpackl(qv2), vec_unpackl(qv3)); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; - vsumi0 = vec_add(vsumi0, vsumi2); - vsumi1 = vec_add(vsumi1, vsumi3); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + vsumi0 = vec_sum4s(qv2, vsumi0); + vsumi1 = vec_sum4s(qv3, vsumi1); vsumi0 = vec_add(vsumi0, vsumi1); @@ -5586,18 +5462,18 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined(__loongarch_asx) // Initialize accumulator with zeros __m256 acc = (__m256)__lasx_xvldi(0); // Main loop - for (int i = 0; i < nb; ++i) { + for (; ib < nb; ++ib) { // Compute combined scale for the block - const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); - __m256i qx = __lasx_xvld((const __m256i *)x[i].qs, 0); - __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); const __m256 q = mul_sum_i8_pairs_float(qx, qy); @@ -5605,24 +5481,19 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r acc = __lasx_xvfmadd_s( d, q, acc ); } - *s = hsum_float_8(acc); - -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { + sumf = hsum_float_8(acc); +#endif + for (; ib < nb; ++ib) { int sumi = 0; for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; + sumi += x[ib].qs[j]*y[ib].qs[j]; } - sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); } *s = sumf; -#endif } void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -5964,6 +5835,7 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0x3); const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v6 = vec_splats((unsigned char)0x6); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -6001,15 +5873,17 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; for (int j = 0; j < QK_K/128; ++j) { __builtin_prefetch(q2, 0, 1); @@ -6019,14 +5893,14 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed char qxs1 = (vector signed char)vec_xl(16, q2); q2 += 32; - vector signed char q2x00 = vec_and(qxs0, lowMask); - vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask); - vector signed char q2x10 = vec_and(qxs1, lowMask); - vector signed char q2x11 = vec_and(vec_sr(qxs1, v2), lowMask); - vector signed char q2x12 = vec_and(vec_sr(qxs1, v4), lowMask); - vector signed char q2x13 = vec_and(vec_sr(qxs1, v6), lowMask); + vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); + vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); + vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); + vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); + vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); + vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); vector signed char q8y00 = vec_xl( 0, q8); vector signed char q8y10 = vec_xl( 16, q8); @@ -6038,45 +5912,36 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed char q8y13 = vec_xl(112, q8); q8 += 128; - vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00)); - vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01)); - vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02)); - vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03)); - vector signed short qv4 = vec_add(vec_mule(q2x10, q8y10), vec_mulo(q2x10, q8y10)); - vector signed short qv5 = vec_add(vec_mule(q2x11, q8y11), vec_mulo(q2x11, q8y11)); - vector signed short qv6 = vec_add(vec_mule(q2x12, q8y12), vec_mulo(q2x12, q8y12)); - vector signed short qv7 = vec_add(vec_mule(q2x13, q8y13), vec_mulo(q2x13, q8y13)); + vector signed int qv0 = vec_msum(q8y00, q2x00, v0); + vector signed int qv1 = vec_msum(q8y01, q2x01, v0); + vector signed int qv2 = vec_msum(q8y02, q2x02, v0); + vector signed int qv3 = vec_msum(q8y03, q2x03, v0); + vector signed int qv4 = vec_msum(q8y10, q2x10, v0); + vector signed int qv5 = vec_msum(q8y11, q2x11, v0); + vector signed int qv6 = vec_msum(q8y12, q2x12, v0); + vector signed int qv7 = vec_msum(q8y13, q2x13, v0); - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - vector signed short vs4 = vec_splat(vscales_h, 4); - vector signed short vs5 = vec_splat(vscales_h, 5); - vector signed short vs6 = vec_splat(vscales_h, 6); - vector signed short vs7 = vec_splat(vscales_h, 7); + vector signed short vscales_07 = vec_unpackh(vscales); + vector signed int vscales_03 = vec_unpackh(vscales_07); + vector signed int vscales_47 = vec_unpackl(vscales_07); + vector signed int vs0 = vec_splat(vscales_03, 0); + vector signed int vs1 = vec_splat(vscales_03, 1); + vector signed int vs2 = vec_splat(vscales_03, 2); + vector signed int vs3 = vec_splat(vscales_03, 3); + vector signed int vs4 = vec_splat(vscales_47, 0); + vector signed int vs5 = vec_splat(vscales_47, 1); + vector signed int vs6 = vec_splat(vscales_47, 2); + vector signed int vs7 = vec_splat(vscales_47, 3); vscales = vec_sld(vscales, vscales, 8); - qv0 = vec_mul(qv0, vs0); - qv1 = vec_mul(qv1, vs2); - qv2 = vec_mul(qv2, vs4); - qv3 = vec_mul(qv3, vs6); - - qv0 = vec_madd(qv4, vs1, qv0); - qv1 = vec_madd(qv5, vs3, qv1); - qv2 = vec_madd(qv6, vs5, qv2); - qv3 = vec_madd(qv7, vs7, qv3); - - vsumi0 = vec_add(vec_unpackh(qv0), vsumi0); - vsumi1 = vec_add(vec_unpackh(qv1), vsumi1); - vsumi2 = vec_add(vec_unpackh(qv2), vsumi2); - vsumi3 = vec_add(vec_unpackh(qv3), vsumi3); - - vsumi4 = vec_add(vec_unpackl(qv0), vsumi4); - vsumi5 = vec_add(vec_unpackl(qv1), vsumi5); - vsumi6 = vec_add(vec_unpackl(qv2), vsumi6); - vsumi7 = vec_add(vec_unpackl(qv3), vsumi7); + vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); + vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); + vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); + vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); + vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); + vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); + vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); } vsumi0 = vec_add(vsumi0, vsumi4); @@ -6667,6 +6532,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowMask1 = vec_splats((int8_t)0xf); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); const vector signed char v1 = vec_splats((signed char)0x1); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v3 = vec_splats((unsigned char)0x3); @@ -6684,30 +6552,33 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - uint32_t aux[3]; - uint32_t utmp[4]; + UNUSED(kmask1); + UNUSED(kmask2); - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(u0, lowMask1); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); + vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); + vector signed char u31 = vec_and(u3, lowMask2); - vector signed char vscales = (vector signed char)vec_xl( 0, utmp); + u1 = vec_or(u1, u30); + u2 = vec_or(vec_sr(u0, v4), u31); + + vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); vscales = vec_sub(vscales, off); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); - + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; const uint8_t * restrict q3 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -6781,23 +6652,14 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); - vector signed int vsum0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); - vector signed int vsum1 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); - vector signed int vsum2 = vec_add(vec_mule(qv02, vs4), vec_mulo(qv02, vs4)); - vector signed int vsum3 = vec_add(vec_mule(qv03, vs6), vec_mulo(qv03, vs6)); - vector signed int vsum4 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); - vector signed int vsum5 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); - vector signed int vsum6 = vec_add(vec_mule(qv12, vs5), vec_mulo(qv12, vs5)); - vector signed int vsum7 = vec_add(vec_mule(qv13, vs7), vec_mulo(qv13, vs7)); - - vsumi0 = vec_add(vsum0, vsumi0); - vsumi1 = vec_add(vsum1, vsumi1); - vsumi2 = vec_add(vsum2, vsumi2); - vsumi3 = vec_add(vsum3, vsumi3); - vsumi4 = vec_add(vsum4, vsumi4); - vsumi5 = vec_add(vsum5, vsumi5); - vsumi6 = vec_add(vsum6, vsumi6); - vsumi7 = vec_add(vsum7, vsumi7); + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs2, vsumi1); + vsumi2 = vec_msum(qv02, vs4, vsumi2); + vsumi3 = vec_msum(qv03, vs6, vsumi3); + vsumi4 = vec_msum(qv10, vs1, vsumi4); + vsumi5 = vec_msum(qv11, vs3, vsumi5); + vsumi6 = vec_msum(qv12, vs5, vsumi6); + vsumi7 = vec_msum(qv13, vs7, vsumi7); } vsumi0 = vec_add(vsumi0, vsumi4); @@ -7296,6 +7158,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((uint8_t)2); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); @@ -7314,15 +7180,24 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - memcpy(utmp, x[i].scales, 12); + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - vector signed char utmps = (vector signed char)vec_xl( 0, utmp); vector signed short vscales = vec_unpackh(utmps); vector signed short q4xmins = vec_unpackl(utmps); vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); @@ -7338,14 +7213,10 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q4 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -7360,14 +7231,14 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed char qxs3 = (vector signed char)vec_xl(48, q4); q4 += 64; - vector signed char q4x00 = vec_and(qxs0, lowMask); - vector signed char q4x01 = vec_sr(qxs0, v4); - vector signed char q4x10 = vec_and(qxs1, lowMask); - vector signed char q4x11 = vec_sr(qxs1, v4); - vector signed char q4x20 = vec_and(qxs2, lowMask); - vector signed char q4x21 = vec_sr(qxs2, v4); - vector signed char q4x30 = vec_and(qxs3, lowMask); - vector signed char q4x31 = vec_sr(qxs3, v4); + vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); + vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); + vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); + vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); + vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); + vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); vector signed char q8y00 = vec_xl( 0, q8); vector signed char q8y10 = vec_xl( 16, q8); @@ -7379,41 +7250,33 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed char q8y31 = vec_xl(112, q8); q8 += 128; - vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01)); - vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11)); - vector signed short qv20 = vec_add(vec_mule(q4x20, q8y20), vec_mulo(q4x20, q8y20)); - vector signed short qv21 = vec_add(vec_mule(q4x21, q8y21), vec_mulo(q4x21, q8y21)); - vector signed short qv30 = vec_add(vec_mule(q4x30, q8y30), vec_mulo(q4x30, q8y30)); - vector signed short qv31 = vec_add(vec_mule(q4x31, q8y31), vec_mulo(q4x31, q8y31)); + vector signed int qv00 = vec_msum(q8y00, q4x00, v0); + vector signed int qv01 = vec_msum(q8y01, q4x01, v0); + vector signed int qv10 = vec_msum(q8y10, q4x10, v0); + vector signed int qv11 = vec_msum(q8y11, q4x11, v0); + vector signed int qv20 = vec_msum(q8y20, q4x20, v0); + vector signed int qv21 = vec_msum(q8y21, q4x21, v0); + vector signed int qv30 = vec_msum(q8y30, q4x30, v0); + vector signed int qv31 = vec_msum(q8y31, q4x31, v0); - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); - vector signed short vs2 = vec_splat(vscales, 2); - vector signed short vs3 = vec_splat(vscales, 3); + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vector signed int vs2 = vec_splat(vscales_h, 2); + vector signed int vs3 = vec_splat(vscales_h, 3); vscales = vec_sld(vscales, vscales, 8); - qv00 = vec_add(qv00, qv10); - qv10 = vec_add(qv01, qv11); - qv20 = vec_add(qv20, qv30); - qv30 = vec_add(qv21, qv31); + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); - vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); - vsumi2 = vec_add(vec_mule(qv10, vs1), vsumi2); - vsumi3 = vec_add(vec_mulo(qv10, vs1), vsumi3); - vsumi4 = vec_add(vec_mule(qv20, vs2), vsumi4); - vsumi5 = vec_add(vec_mulo(qv20, vs2), vsumi5); - vsumi6 = vec_add(vec_mule(qv30, vs3), vsumi6); - vsumi7 = vec_add(vec_mulo(qv30, vs3), vsumi7); + vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -7915,6 +7778,9 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v1 = vec_splats((unsigned char)0x1); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v3 = vec_splats((unsigned char)0x3); @@ -7933,18 +7799,27 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); vector float vdmin = vec_mul(vxmin, vyd); - memcpy(utmp, x[i].scales, 12); + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - vector signed char utmps = (vector signed char)vec_xl( 0, utmp); vector signed short vscales = vec_unpackh(utmps); vector signed short q5xmins = vec_unpackl(utmps); @@ -7964,10 +7839,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -7992,10 +7867,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r qxhs0 = vec_sr(qxhs0, v2); qxhs1 = vec_sr(qxhs1, v2); - vector signed char q5x00 = vec_or(q5h00, qxs00); - vector signed char q5x01 = vec_or(q5h01, qxs01); - vector signed char q5x10 = vec_or(q5h10, qxs10); - vector signed char q5x11 = vec_or(q5h11, qxs11); + vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); + vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); + vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); + vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); vector signed char q8y00 = vec_xl( 0, q8); vector signed char q8y10 = vec_xl(16, q8); @@ -8003,22 +7878,20 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed char q8y11 = vec_xl(48, q8); q8 += 64; - vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01)); - vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11)); + vector signed int qv00 = vec_msum(q8y00, q5x00, v0); + vector signed int qv01 = vec_msum(q8y01, q5x01, v0); + vector signed int qv10 = vec_msum(q8y10, q5x10, v0); + vector signed int qv11 = vec_msum(q8y11, q5x11, v0); - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); vscales = vec_sld(vscales, vscales, 12); - qv00 = vec_add(qv00, qv10); - qv01 = vec_add(qv01, qv11); - - vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); - vsumi2 = vec_add(vec_mule(qv01, vs1), vsumi2); - vsumi3 = vec_add(vec_mulo(qv01, vs1), vsumi3); + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); + vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); } vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); @@ -8579,6 +8452,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v2 = vec_splats((unsigned char)0x2); const vector unsigned char v3 = vec_splats((unsigned char)0x3); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -8595,14 +8469,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict qh = x[i].qh; @@ -8682,23 +8556,14 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed short vs6 = vec_splat(vscales, 6); vector signed short vs7 = vec_splat(vscales, 7); - vsumi0 = vec_add(vec_mule(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mulo(qv00, vs0), vsumi1); - vsumi2 = vec_add(vec_mule(qv01, vs4), vsumi2); - vsumi3 = vec_add(vec_mulo(qv01, vs4), vsumi3); - vsumi4 = vec_add(vec_mule(qv10, vs1), vsumi4); - vsumi5 = vec_add(vec_mulo(qv10, vs1), vsumi5); - vsumi6 = vec_add(vec_mule(qv11, vs5), vsumi6); - vsumi7 = vec_add(vec_mulo(qv11, vs5), vsumi7); - - vsumi0 = vec_add(vec_mule(qv20, vs2), vsumi0); - vsumi1 = vec_add(vec_mulo(qv20, vs2), vsumi1); - vsumi2 = vec_add(vec_mule(qv21, vs6), vsumi2); - vsumi3 = vec_add(vec_mulo(qv21, vs6), vsumi3); - vsumi4 = vec_add(vec_mule(qv30, vs3), vsumi4); - vsumi5 = vec_add(vec_mulo(qv30, vs3), vsumi5); - vsumi6 = vec_add(vec_mule(qv31, vs7), vsumi6); - vsumi7 = vec_add(vec_mulo(qv31, vs7), vsumi7); + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs4, vsumi1); + vsumi2 = vec_msum(qv10, vs1, vsumi2); + vsumi3 = vec_msum(qv11, vs5, vsumi3); + vsumi4 = vec_msum(qv20, vs2, vsumi4); + vsumi5 = vec_msum(qv21, vs6, vsumi5); + vsumi6 = vec_msum(qv30, vs3, vsumi6); + vsumi7 = vec_msum(qv31, vs7, vsumi7); } vsumi0 = vec_add(vsumi0, vsumi4); @@ -8845,7 +8710,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r #endif } -#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) +#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -8978,7 +8843,63 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void *s = 0.125f * hsum_float_8(accumf); +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -8991,14 +8912,10 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint16_t * restrict q2 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -9045,21 +8962,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -9333,6 +9241,165 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * } *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const __m128i mone = _mm_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); + const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); + const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); + const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); + const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); + const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; + aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); + + const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); + const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); + const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); + const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); + const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); + const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); + + const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); + const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); + const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); + const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); + + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); + const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); + const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); + const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); + + // AVX2 full_signs_1 is full_sign_bits_0 here + // AVX2 full_signs_2 is full_sign_bits_1 here + __m128i signs_0, signs_1; + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); + const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); + const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); + const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); + + __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); + const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); + const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); + const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); + const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + #elif defined(__loongarch_asx) const __m256i mone = __lasx_xvreplgr2vr_b(1); @@ -9451,6 +9518,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * *s = 0.125f * hsum_float_8(accumf); #elif defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -9463,14 +9531,10 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint16_t * restrict q2 = x[i].qs; const uint8_t * restrict sc = x[i].scales; @@ -9518,21 +9582,12 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7); + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -9748,6 +9803,98 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * *s = 0.125f * hsum_float_8(accumf); +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); + const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); + qs += 8; + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 @@ -9755,6 +9902,8 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -9769,14 +9918,10 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q2 = x[i].qs; const uint8_t * restrict qh = x[i].qh; @@ -9836,21 +9981,12 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales0), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales1), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales2), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales3), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales0), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales1), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales2), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales3), vsumi7); + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -10085,9 +10221,68 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void *s = 0.25f * hsum_float_8(accumf); +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -10098,14 +10293,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void vector float vyd = vec_splats(y[i].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; const uint8_t * restrict q3 = x[i].qs; const uint32_t * restrict signs = (const uint32_t *)(x[i].qs + QK_K/4); @@ -10150,21 +10341,12 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -10447,6 +10629,112 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * *s = hsum_float_8(accumf); +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); + const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); + const __m128i idx_mask = _mm_set1_epi32(256); + + typedef union { + __m128i vec[4]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); + const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); + const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; + idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); + idx.vec[1] = idx.vec[0]; + idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); + idx.vec[3] = idx.vec[2]; + + idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); + idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); + idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); + idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); + + idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); + idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); + idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); + idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); + + const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); + const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = hsum_float_8(accumf); + #elif defined(__POWER9_VECTOR__) static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 @@ -10454,6 +10742,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); vector float vsumf2 = vec_splats(0.0f); @@ -10474,14 +10764,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * const uint8_t * restrict sc = x[i].scales; const int8_t * restrict q8 = y[i].qs; - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; for (int j = 0; j < QK_K/32; j += 2) { __builtin_prefetch(q3, 0, 1); @@ -10535,21 +10821,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -10695,6 +10972,14 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * } +#if defined(__AVX__) +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} +#endif + #if defined(__AVX2__) static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { const __m256i ax = _mm256_sign_epi8(x, x); @@ -10812,6 +11097,54 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; +#elif defined __AVX__ + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); + qs += 8; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + #elif defined(__POWER9_VECTOR__) const vector unsigned char v0 = vec_splats((unsigned char)0x0); const vector unsigned short vsign = vec_splats((unsigned short)0x8000); @@ -10830,10 +11163,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void vector signed int vsumi1 = vec_splats((int32_t)0); vector signed int vsumi2 = vec_splats((int32_t)0); vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); vector signed int vsumi8 = vec_splats((int32_t)0); const uint8_t * restrict q1 = x[i].qs; @@ -10875,14 +11204,10 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); vector signed short vscales = vec_sld(vscales23, vscales01, 8); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); vector signed short q8ysums = vec_xl_len(qs, 8); qs += 4; @@ -10897,11 +11222,6 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -11163,6 +11483,92 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); +#elif defined __AVX__ + const __m128i mask = _mm_set1_epi16(0x7); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x( + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x( + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + + const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); + const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); + const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); + const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); + + __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); + __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); + __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); + __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); + + scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); + scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); + scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); + scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); + const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); + const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); + const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); + const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); + const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + #else int sum1[2], sum2[2], delta[4]; @@ -11227,6 +11633,9 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * const int nb = n / QK4_NL; + int ib = 0; + float sumf = 0; + #if defined __ARM_NEON const int8x16_t values = vld1q_s8(kvalues_iq4nl); const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -11235,16 +11644,14 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * int8x16x4_t q8b; int32x4_t prod_1, prod_2; - float sumf = 0; + for (; ib + 1 < nb; ib += 2) { - for (int ib = 0; ib < nb; ib += 2) { - - q4bits.val[0] = vld1q_u8(x[ib+0].qs); - q4bits.val[1] = vld1q_u8(x[ib+1].qs); - q8b.val[0] = vld1q_s8(y[ib+0].qs); - q8b.val[1] = vld1q_s8(y[ib+0].qs + 16); - q8b.val[2] = vld1q_s8(y[ib+1].qs); - q8b.val[3] = vld1q_s8(y[ib+1].qs + 16); + q4bits.val[0] = vld1q_u8(x[ib + 0].qs); + q4bits.val[1] = vld1q_u8(x[ib + 1].qs); + q8b.val[0] = vld1q_s8(y[ib + 0].qs); + q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib + 1].qs); + q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); @@ -11255,12 +11662,10 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); sumf += - GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib+0].d) * vaddvq_s32(prod_1) + - GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib+1].d) * vaddvq_s32(prod_2); + GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + + GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); } - *s = sumf; - #elif defined __AVX2__ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); @@ -11269,11 +11674,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); - for (int ib = 0; ib < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs); - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), @@ -11282,19 +11687,52 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)), + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)), + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), _mm256_cvtepi32_ps(p_2), accum2); - - y += 2; - x += 2; } - *s = hsum_float_8(_mm256_add_ps(accum1, accum2)); + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), + _mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), + _mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2); + } + + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); @@ -11303,7 +11741,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * const vector signed char values = vec_xl( 0, kvalues_iq4nl); #pragma GCC unroll 4 - for (int ib = 0; ib < nb; ++ib) { + for (; ib < nb; ++ib) { __builtin_prefetch(x[ib].qs, 0, 1); __builtin_prefetch(y[ib].qs, 0, 1); @@ -11325,8 +11763,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); - vector signed int vsumi1 = vec_add(vec_unpackh(qv1), vec_unpackl(qv1)); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); @@ -11337,7 +11778,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - *s = vec_extract(vsumf0, 0); + sumf = vec_extract(vsumf0, 0); #elif defined (__loongarch_asx) @@ -11347,11 +11788,11 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * __m256 accum1 = (__m256)__lasx_xvldi(0); __m256 accum2 = (__m256)__lasx_xvldi(0); - for (int ib = 0; ib < nb; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[0].qs, 0); - const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0); - const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0); - const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); + const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); + const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); + const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), @@ -11360,20 +11801,16 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = lasx_madd_h(p16_1, mone); const __m256i p_2 = lasx_madd_h(p16_2, mone); - accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)), + accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), __lasx_xvffint_s_w(p_1), accum1); - accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)), + accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), __lasx_xvffint_s_w(p_2), accum2); - - y += 2; - x += 2; } - *s = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); -#else - float sumf = 0; - for (int ib = 0; ib < nb; ++ib) { +#endif + for (; ib < nb; ++ib) { const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); int sumi1 = 0, sumi2 = 0; for (int j = 0; j < QK4_NL/2; ++j) { @@ -11383,7 +11820,6 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * sumf += d * (sumi1 + sumi2); } *s = sumf; -#endif } void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -11479,8 +11915,57 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * *s = hsum_float_8(accum); +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); + sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); + sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); + sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); + sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); + } + __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); + __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); + } + + *s = hsum_float_8(accum); + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); const vector unsigned char v4 = vec_splats((unsigned char)0x4); vector float vsumf0 = vec_splats(0.0f); @@ -11496,14 +11981,10 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * vector float vyd = vec_splats(y[ibl].d); vector float vd = vec_mul(vxd, vyd); - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi4 = vec_splats((int32_t)0); - vector signed int vsumi5 = vec_splats((int32_t)0); - vector signed int vsumi6 = vec_splats((int32_t)0); - vector signed int vsumi7 = vec_splats((int32_t)0); + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; uint16_t h = x[ibl].scales_h; @@ -11548,21 +12029,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * vector signed short vscales01 = vec_splats((int16_t)ls0); vector signed short vscales23 = vec_splats((int16_t)ls1); - vsumi0 = vec_add(vec_mule(qv0, vscales01), vsumi0); - vsumi1 = vec_add(vec_mule(qv1, vscales01), vsumi1); - vsumi2 = vec_add(vec_mule(qv2, vscales23), vsumi2); - vsumi3 = vec_add(vec_mule(qv3, vscales23), vsumi3); - vsumi4 = vec_add(vec_mulo(qv0, vscales01), vsumi4); - vsumi5 = vec_add(vec_mulo(qv1, vscales01), vsumi5); - vsumi6 = vec_add(vec_mulo(qv2, vscales23), vsumi6); - vsumi7 = vec_add(vec_mulo(qv3, vscales23), vsumi7); + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); } - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); @@ -12258,7 +12730,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict printf("Oops: found point %u not on grid:", u); for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); printf("\n"); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } q2[2*ib+0] |= ((uint32_t) grid_index << 8*k); q2[2*ib+1] |= (block_signs[k] << 7*k); @@ -12437,7 +12909,7 @@ static void quantize_row_iq2_xs_impl(const float * restrict x, void * restrict v printf("Oops: found point %u not on grid:", u); for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); printf("\n"); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } q2[2*ib+k] = grid_index | (block_signs[k] << 9); } @@ -12880,7 +13352,7 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, v printf("Oops: found point %u not on grid:", u); for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); printf("\n"); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } if (grid_size == 256) { q3[8*ib+k] = grid_index; @@ -12934,10 +13406,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_xxs * restrict y = vy; - quantize_row_iq3_xxs_reference(x, y, k); + quantize_row_iq3_xxs_ref(x, y, k); } -void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { +void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_row_iq3_xxs_impl(256, x, y, k, NULL); } @@ -13093,7 +13565,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * restrict x, vo printf("Oops: found point %u not on grid:", u); for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); printf("\n"); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } qs[k] = grid_index & 255; qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8)); @@ -13150,10 +13622,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq3_s * restrict y = vy; - quantize_row_iq3_s_reference(x, y, k); + quantize_row_iq3_s_ref(x, y, k); } -void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) { +void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq3_s(x, y, 1, k, NULL); } @@ -13165,7 +13637,7 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u const float * restrict xval, const float * restrict weight, float * scale, int8_t * restrict L, int ngrid) { int num_neighbors = neighbours[0]; GGML_ASSERT(num_neighbors > 0); - float best_score = 0; + float best_score = -FLT_MAX; int grid_index = -1; for (int j = 1; j <= num_neighbors; ++j) { const int8_t * pg = (const int8_t *)(grid + neighbours[j]); @@ -13363,7 +13835,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy sumw[j+1] = sumw[j] + weight[i]; } } - float best_score = 0, scale = max; + float best_score = -FLT_MIN, scale = max; int besti1 = -1, besti2 = -1, best_shift = 0; for (int i1 = 0; i1 <= block_size; ++i1) { for (int i2 = i1; i2 <= block_size; ++i2) { @@ -13539,7 +14011,7 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy idx[2*j] = j; } qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper); - float best_score = 0, scale = max; + float best_score = -FLT_MIN, scale = max; int besti1 = -1, besti2 = -1, best_k = -1; // 0: +, + // 1: +, - @@ -13891,7 +14363,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k } } -void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { +void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) { assert(k % QK4_NL == 0); quantize_row_iq4_nl(x, y, k); } @@ -13919,10 +14391,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq4_xs * restrict y = vy; - quantize_row_iq4_xs_reference(x, y, k); + quantize_row_iq4_xs_ref(x, y, k); } -void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { +void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq4_xs(x, y, 1, k, NULL); } @@ -14069,7 +14541,7 @@ static void quantize_row_iq2_s_impl(const float * restrict x, void * restrict vy printf("Oops: found point %u not on grid:", u); for (int i = 0; i < 8; ++i) printf(" %d", L[8*k+i]); printf("\n"); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } const int i8 = 2*ib + k; y[ibl].qs[i8] = grid_index & 255; @@ -14109,7 +14581,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n return nrow * nblock * sizeof(block_iq2_s); } -void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) { +void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) { assert(k % QK_K == 0); quantize_iq2_s(x, y, 1, k, NULL); } @@ -14117,7 +14589,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) { assert(k % QK_K == 0); block_iq2_s * restrict y = vy; - quantize_row_iq2_s_reference(x, y, k); + quantize_row_iq2_s_ref(x, y, k); } static bool validate_float(float f, size_t i) { @@ -14172,6 +14644,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { } \ } +#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + for (size_t j = 0; j < (nr); ++j) { \ + if (!validate_fp16(q[i].d[j], i)) { \ + return false; \ + } \ + } \ + } + bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) { if (type < 0 || type >= GGML_TYPE_COUNT) { fprintf(stderr, "%s: invalid type %d\n", __func__, type); @@ -14179,7 +14661,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } if (nbytes % ggml_type_size(type) != 0) { - fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type); + fprintf(stderr, "%s: invalid size %zu for type %s (type size = %zu)\n", __func__, nbytes, ggml_type_name(type), ggml_type_size(type)); return false; } @@ -14389,6 +14871,16 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + { + VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4); + } break; + case GGML_TYPE_Q4_0_8_8: + { + VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8); + } break; + case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/llama/ggml-quants.h b/llama/ggml-quants.h index c9a9b732..39ece43c 100644 --- a/llama/ggml-quants.h +++ b/llama/ggml-quants.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -38,25 +38,25 @@ extern "C" { #endif // Quantization -void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); -void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); -void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); -void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); -void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); +void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); +void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); +void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -156,4 +156,3 @@ void iq3xs_free_impl(int grid_size); #ifdef __cplusplus } #endif - diff --git a/llama/ggml.c b/llama/ggml.c index e1d2f674..e7822f91 100644 --- a/llama/ggml.c +++ b/llama/ggml.c @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -30,7 +30,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml.h" - +#include "ggml-aarch64.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -63,12 +63,12 @@ #include #endif -#ifdef __ARM_FEATURE_MATMUL_INT8 +#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8) #undef GGML_USE_LLAMAFILE #endif #ifdef GGML_USE_LLAMAFILE -#include "sgemm.h" +#include #endif #if defined(_MSC_VER) @@ -167,23 +167,25 @@ typedef pthread_t ggml_thread_t; #include -void ggml_print_backtrace(void) { - /* - #include - #include - +#if defined(__linux__) +#include +static void ggml_print_backtrace_symbols(void) { void * trace[100]; - int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); - backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); - */ +} +#else +static void ggml_print_backtrace_symbols(void) { + // platform not supported +} +#endif - // backtrack_symbols does not show line numbers, use gdb instead +static void ggml_print_backtrace(void) { char attach[32]; snprintf(attach, sizeof(attach), "attach %d", getpid()); int pid = fork(); if (pid == 0) { + // try gdb execlp("gdb", "gdb", "--batch", "-ex", "set style enabled on", "-ex", attach, @@ -191,17 +193,46 @@ void ggml_print_backtrace(void) { "-ex", "detach", "-ex", "quit", (char *) NULL); + // try lldb + execlp("lldb", "lldb", "--batch", + "-o", "bt", + "-o", "quit", + "-p", attach, + (char *) NULL); + exit(EXIT_FAILURE); } else { - waitpid(pid, NULL, 0); + int wstatus; + waitpid(pid, &wstatus, 0); + if (WIFEXITED(wstatus)) { + if (WEXITSTATUS(wstatus) == EXIT_FAILURE) { + // gdb failed, fallback to backtrace_symbols + ggml_print_backtrace_symbols(); + } + } } } #else -void ggml_print_backtrace(void) { +static void ggml_print_backtrace(void) { // platform not supported } #endif -/*#define GGML_PERF*/ +void ggml_abort(const char * file, int line, const char * fmt, ...) { + fflush(stdout); + + fprintf(stderr, "%s:%d: ", file, line); + + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); + + ggml_print_backtrace(); + abort(); +} + #define GGML_DEBUG 0 #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 @@ -273,7 +304,7 @@ inline static void * ggml_aligned_malloc(size_t size) { break; } GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); return NULL; } return aligned_memory; @@ -294,7 +325,7 @@ inline static void * ggml_malloc(size_t size) { void * result = malloc(size); if (result == NULL) { GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } return result; } @@ -308,7 +339,7 @@ inline static void * ggml_calloc(size_t num, size_t size) { void * result = calloc(num, size); if (result == NULL) { GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0)); - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } return result; } @@ -319,16 +350,10 @@ inline static void * ggml_calloc(size_t num, size_t size) { #define GGML_FREE(ptr) free(ptr) #define UNUSED GGML_UNUSED -#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) +#define SWAP(x, y, T) do { T SWAP = x; (x) = y; (y) = SWAP; } while (0) #if defined(GGML_USE_ACCELERATE) #include -#elif defined(GGML_USE_OPENBLAS) -#if defined(GGML_BLAS_USE_MKL) -#include -#else -#include -#endif #endif // floating point type used to accumulate sums @@ -506,18 +531,6 @@ int64_t ggml_cycles_per_ms(void) { return CLOCKS_PER_SEC/1000; } -#ifdef GGML_PERF -#define ggml_perf_time_ms() ggml_time_ms() -#define ggml_perf_time_us() ggml_time_us() -#define ggml_perf_cycles() ggml_cycles() -#define ggml_perf_cycles_per_ms() ggml_cycles_per_ms() -#else -#define ggml_perf_time_ms() 0 -#define ggml_perf_time_us() 0 -#define ggml_perf_cycles() 0 -#define ggml_perf_cycles_per_ms() 0 -#endif - // // cross-platform UTF-8 file paths // @@ -637,7 +650,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, @@ -649,7 +662,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_0, .from_float = quantize_row_q4_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref, .vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -665,7 +678,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_1, .from_float = quantize_row_q4_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, .vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -681,7 +694,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, @@ -693,7 +706,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = NULL, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = NULL, .vec_dot_type = GGML_TYPE_COUNT, .nrows = 1, @@ -705,7 +718,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_0, .from_float = quantize_row_q5_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref, .vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -717,7 +730,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_1, .from_float = quantize_row_q5_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref, .vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, @@ -729,7 +742,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q8_0, .from_float = quantize_row_q8_0, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref, + .from_float_to_mat = quantize_mat_q8_0, .vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -744,7 +758,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .type_size = sizeof(block_q8_1), .is_quantized = true, .from_float = quantize_row_q8_1, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, }, @@ -755,7 +769,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q2_K, .from_float = quantize_row_q2_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref, .vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -767,7 +781,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q3_K, .from_float = quantize_row_q3_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, .vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -779,7 +793,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_K, .from_float = quantize_row_q4_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref, .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -791,7 +805,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q5_K, .from_float = quantize_row_q5_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, .vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -803,7 +817,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q6_K, .from_float = quantize_row_q6_K, - .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, + .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -815,7 +829,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq2_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -827,7 +841,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_xs, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq2_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -839,7 +853,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, .from_float = quantize_row_iq3_xxs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -851,7 +865,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq3_s, .from_float = quantize_row_iq3_s, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref, .vec_dot = ggml_vec_dot_iq3_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -863,7 +877,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq2_s, .from_float = quantize_row_iq2_s, - .from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref, .vec_dot = ggml_vec_dot_iq2_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -875,7 +889,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_s, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq1_s_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -887,7 +901,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq1_m, .from_float = NULL, - .from_float_reference = NULL, + .from_float_ref = NULL, .vec_dot = ggml_vec_dot_iq1_m_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -899,7 +913,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, .from_float = quantize_row_iq4_nl, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, @@ -911,7 +925,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, .from_float = quantize_row_iq4_xs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference, + .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref, .vec_dot = ggml_vec_dot_iq4_xs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -930,10 +944,58 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = false, .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, - .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, + }, + [GGML_TYPE_Q4_0_4_4] = { + .type_name = "q4_0_4x4", + .blck_size = QK4_0, + .blck_size_interleave = 4, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .gemv = ggml_gemv_q4_0_4x4_q8_0, + .gemm = ggml_gemm_q4_0_4x4_q8_0, + }, + [GGML_TYPE_Q4_0_4_8] = { + .type_name = "q4_0_4x8", + .blck_size = QK4_0, + .blck_size_interleave = 8, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 4, + .gemv = ggml_gemv_q4_0_4x8_q8_0, + .gemm = ggml_gemm_q4_0_4x8_q8_0, + }, + [GGML_TYPE_Q4_0_8_8] = { + .type_name = "q4_0_8x8", + .blck_size = QK4_0, + .blck_size_interleave = 8, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = NULL, + .from_float = NULL, + .from_float_ref = NULL, + .vec_dot = NULL, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + .ncols = 8, + .gemv = ggml_gemv_q4_0_8x8_q8_0, + .gemm = ggml_gemm_q4_0_8x8_q8_0, } }; @@ -1762,8 +1824,8 @@ struct ggml_context { int n_objects; - struct ggml_object* objects_begin; - struct ggml_object* objects_end; + struct ggml_object * objects_begin; + struct ggml_object * objects_end; struct ggml_scratch scratch; struct ggml_scratch scratch_save; @@ -1776,30 +1838,38 @@ struct ggml_context_container { }; struct ggml_compute_state_shared { - const struct ggml_cgraph* cgraph; - const struct ggml_cplan* cplan; - - int64_t perf_node_start_cycles; - int64_t perf_node_start_time_us; + const struct ggml_cgraph * cgraph; + const struct ggml_cplan * cplan; int n_threads; // synchronization primitives - atomic_int n_active; // num active threads - atomic_int node_n; // active graph node - atomic_int node_task; // active graph node task phase + atomic_int n_barrier; + atomic_int n_barrier_passed; ggml_abort_callback abort_callback; // abort ggml_graph_compute when true - void* abort_callback_data; + void * abort_callback_data; - atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. + atomic_int current_chunk; // currently processing chunk during mul_mat, shared between all the threads + + enum ggml_status ec; }; struct ggml_compute_state { ggml_thread_t thrd; int ith; - struct ggml_compute_state_shared* shared; - enum ggml_status ec; + struct ggml_compute_state_shared * shared; +}; + +struct ggml_compute_params { + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + + struct ggml_compute_state_shared * shared; }; // @@ -2847,42 +2917,6 @@ static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); -// WARN: -// Mis-configuration can lead to problem that's hard to reason about: -// * At best it crash or talks nosense. -// * At worst it talks slightly difference but hard to perceive. -// -// An op has to enable INIT or FINALIZE when any of it's branch needs that pass. -// Take care about compile options (e.g., GGML_USE_xxx). -static bool GGML_OP_HAS_INIT [GGML_OP_COUNT] = { 0 }; -static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 }; - -static void ggml_setup_op_has_task_pass(void) { - { // INIT - bool * p = GGML_OP_HAS_INIT; - - p[GGML_OP_ACC ] = true; - p[GGML_OP_MUL_MAT ] = true; - p[GGML_OP_MUL_MAT_ID ] = true; - p[GGML_OP_OUT_PROD ] = true; - p[GGML_OP_SET ] = true; - p[GGML_OP_GET_ROWS_BACK ] = true; - p[GGML_OP_DIAG_MASK_INF ] = true; - p[GGML_OP_DIAG_MASK_ZERO ] = true; - p[GGML_OP_CONV_TRANSPOSE_1D ] = true; - p[GGML_OP_CONV_TRANSPOSE_2D ] = true; - p[GGML_OP_FLASH_ATTN_BACK ] = true; - p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; - p[GGML_OP_ADD_REL_POS ] = true; - } - - { // FINALIZE - bool * p = GGML_OP_HAS_FINALIZE; - - p[GGML_OP_CROSS_ENTROPY_LOSS ] = true; - } -} - // // NUMA support // @@ -2921,7 +2955,7 @@ struct ggml_state { static struct ggml_state g_state; static atomic_flag g_state_critical = ATOMIC_FLAG_INIT; -// barrier via spin lock +// critical section via spin lock inline static void ggml_critical_section_start(void) { while (atomic_flag_test_and_set(&g_state_critical)) { // spin @@ -2929,6 +2963,48 @@ inline static void ggml_critical_section_start(void) { } } +#ifdef GGML_USE_OPENMP +static void ggml_barrier(struct ggml_compute_state_shared * shared) { + if (shared->n_threads == 1) { + return; + } + + #pragma omp barrier +} +#else +static void ggml_barrier(struct ggml_compute_state_shared * shared) { + if (shared->n_threads == 1) { + return; + } + + atomic_int * n_barrier = &shared->n_barrier; + atomic_int * n_barrier_passed = &shared->n_barrier_passed; + + int n_threads = shared->n_threads; + int passed_old = atomic_load(n_barrier_passed); + + if (atomic_fetch_add(n_barrier, 1) == n_threads - 1) { + // last thread + atomic_store(n_barrier, 0); + atomic_fetch_add(n_barrier_passed, 1); + } else { + // wait for other threads + const int n_spin_before_sleep = 100000; + while (true) { + for (int i = 0; i < n_spin_before_sleep; i++) { + if (atomic_load(n_barrier_passed) != passed_old) { + return; + } + #if defined(__SSE3__) + _mm_pause(); + #endif + } + sched_yield(); + } + } +} +#endif + // TODO: make this somehow automatically executed // some sort of "sentry" mechanism inline static void ggml_critical_section_end(void) { @@ -3033,7 +3109,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { } } #else - GGML_UNUSED(numa_flag); + UNUSED(numa_flag); // TODO #endif } @@ -3097,7 +3173,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } -GGML_CALL int ggml_blck_size(enum ggml_type type) { +GGML_CALL int64_t ggml_blck_size(enum ggml_type type) { return type_traits[type].blck_size; } @@ -3139,9 +3215,7 @@ GGML_CALL const char * ggml_op_desc(const struct ggml_tensor * t) { enum ggml_unary_op uop = ggml_get_unary_op(t); return ggml_unary_op_name(uop); } - else { - return ggml_op_name(t->op); - } + return ggml_op_name(t->op); } GGML_CALL size_t ggml_element_size(const struct ggml_tensor * tensor) { @@ -3221,6 +3295,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; + case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break; + case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break; + case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; } @@ -3238,35 +3315,42 @@ GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor) { return tensor->nb[0] > tensor->nb[1]; } -GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); +static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { + size_t next_nb = ggml_type_size(tensor->type); + if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) { + return false; + } + next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + if (tensor->ne[i] != 1) { + if (i > n) { + if (tensor->nb[i] != next_nb) { + return false; + } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; + } + } + } + return true; +} - return - tensor->nb[0] == ggml_type_size(tensor->type) && - tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { + return ggml_is_contiguous_0(tensor); } GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { - return ggml_is_contiguous(tensor); + return ggml_is_contiguous_n(tensor, 0); } GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == ggml_type_size(tensor->type) && - tensor->nb[2] == tensor->nb[1]*tensor->ne[1] && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; + return ggml_is_contiguous_n(tensor, 1); } GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { - static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - - return - tensor->nb[0] == ggml_type_size(tensor->type) && - tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; + return ggml_is_contiguous_n(tensor, 2); } GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) { @@ -3298,24 +3382,24 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return - (t0->ne[0] == t1->ne[0] ) && - (t0->ne[1] == t1->ne[1] ) && - (t0->ne[2] == t1->ne[2] ) && - (t0->ne[3] == t1->ne[3] ); + (t0->ne[0] == t1->ne[0]) && + (t0->ne[1] == t1->ne[1]) && + (t0->ne[2] == t1->ne[2]) && + (t0->ne[3] == t1->ne[3]); } bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return - (t0->nb[0] == t1->nb[0] ) && - (t0->nb[1] == t1->nb[1] ) && - (t0->nb[2] == t1->nb[2] ) && - (t0->nb[3] == t1->nb[3] ); + (t0->nb[0] == t1->nb[0]) && + (t0->nb[1] == t1->nb[1]) && + (t0->nb[2] == t1->nb[2]) && + (t0->nb[3] == t1->nb[3]); } // check if t1 can be represented as a repeatition of t0 -static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { +bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return ggml_is_empty(t0) ? ggml_is_empty(t1) : @@ -3346,7 +3430,7 @@ static inline int ggml_up(int n, int m) { } // assert that pointer is aligned to GGML_MEM_ALIGN -#define ggml_assert_aligned(ptr) \ +#define GGML_ASSERT_ALIGNED(ptr) \ GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0) //////////////////////////////////////////////////////////////////////////////// @@ -3401,8 +3485,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } - ggml_setup_op_has_task_pass(); - is_first_call = false; } @@ -3449,7 +3531,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_ASSERT(ctx->mem_buffer != NULL); - ggml_assert_aligned(ctx->mem_buffer); + GGML_ASSERT_ALIGNED(ctx->mem_buffer); GGML_PRINT_DEBUG("%s: context initialized\n", __func__); @@ -3581,7 +3663,7 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml .type = type, }; - ggml_assert_aligned(mem_buffer + obj_new->offs); + GGML_ASSERT_ALIGNED(mem_buffer + obj_new->offs); if (obj_cur != NULL) { obj_cur->next = obj_new; @@ -3669,15 +3751,12 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.flags =*/ 0, /*.grad =*/ NULL, /*.src =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, /*.view_src =*/ view_src, /*.view_offs =*/ view_offs, /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - /*.padding =*/ { 0 }, + ///*.padding =*/ { 0 }, }; #ifdef __clang__ @@ -3685,7 +3764,7 @@ static struct ggml_tensor * ggml_new_tensor_impl( #endif // TODO: this should not be needed as long as we don't rely on aligned SIMD loads - //ggml_assert_aligned(result->data); + //GGML_ASSERT_ALIGNED(result->data); for (int i = 0; i < n_dims; i++) { result->ne[i] = ne[i]; @@ -3858,8 +3937,8 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } return tensor; @@ -3917,8 +3996,8 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } return tensor; @@ -3987,11 +4066,9 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { } default: { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } - - return 0.0f; } void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { @@ -4034,8 +4111,8 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -4055,10 +4132,8 @@ int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i case GGML_TYPE_F32: return ((float *) data)[0]; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } - - return 0.0f; } void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { @@ -4090,8 +4165,8 @@ void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -4104,41 +4179,33 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { case GGML_TYPE_I8: { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; } case GGML_TYPE_I16: { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; } case GGML_TYPE_I32: { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; } case GGML_TYPE_F16: { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]); } case GGML_TYPE_BF16: { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]); } case GGML_TYPE_F32: { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; } default: { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } - - return 0.0f; } void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { @@ -4151,38 +4218,32 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { switch (tensor->type) { case GGML_TYPE_I8: { - GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); ((int8_t *)(tensor->data))[i] = value; } break; case GGML_TYPE_I16: { - GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); ((int16_t *)(tensor->data))[i] = value; } break; case GGML_TYPE_I32: { - GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); ((int32_t *)(tensor->data))[i] = value; } break; case GGML_TYPE_F16: { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t)); ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value); } break; case GGML_TYPE_BF16: { - GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t)); ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value); } break; case GGML_TYPE_F32: { - GGML_ASSERT(tensor->nb[0] == sizeof(float)); ((float *)(tensor->data))[i] = value; } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -4202,10 +4263,8 @@ float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, case GGML_TYPE_F32: return ((float *) data)[0]; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } - - return 0.0f; } void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { @@ -4237,8 +4296,8 @@ void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -4261,8 +4320,11 @@ const char * ggml_get_name(const struct ggml_tensor * tensor) { } struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name) { - strncpy(tensor->name, name, sizeof(tensor->name) - 1); - tensor->name[sizeof(tensor->name) - 1] = '\0'; + size_t i; + for (i = 0; i < sizeof(tensor->name) - 1 && name[i] != '\0'; i++) { + tensor->name[i] = name[i]; + } + tensor->name[i] = '\0'; return tensor; } @@ -4833,7 +4895,7 @@ struct ggml_tensor * ggml_mean( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement is_node = true; } @@ -4856,7 +4918,7 @@ struct ggml_tensor * ggml_argmax( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); is_node = true; } @@ -5179,7 +5241,7 @@ static struct ggml_tensor * ggml_norm_impl( bool is_node = false; if (!inplace && (a->grad)) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -5282,7 +5344,7 @@ static struct ggml_tensor * ggml_group_norm_impl( bool is_node = false; if (!inplace && (a->grad)) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -5355,7 +5417,7 @@ void ggml_mul_mat_set_prec( as -> [cols, rows, n_expert] ids -> [n_experts_used, n_tokens] (i32) b -> [cols, n_expert_used, n_tokens] - c -> [cols, n_expert_used, n_tokens] + c -> [rows, n_expert_used, n_tokens] in b, n_experts_used can be broadcasted to match the n_expert_used of ids @@ -5696,7 +5758,7 @@ struct ggml_tensor * ggml_reshape( if (b->grad) { // gradient propagation is not supported - //GGML_ASSERT(false); + //GGML_ABORT("fatal error"); } struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b->ne, a, 0); @@ -6479,7 +6541,7 @@ struct ggml_tensor * ggml_clamp( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6555,7 +6617,7 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d( bool is_node = false; if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6627,7 +6689,7 @@ struct ggml_tensor * ggml_im2col( bool is_node = false; if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6713,7 +6775,7 @@ struct ggml_tensor * ggml_conv_transpose_2d_p0( bool is_node = false; if (a->grad || b->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6754,7 +6816,7 @@ struct ggml_tensor * ggml_pool_1d( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6792,7 +6854,7 @@ struct ggml_tensor * ggml_pool_2d( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6825,7 +6887,7 @@ static struct ggml_tensor * ggml_upscale_impl( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6875,7 +6937,7 @@ struct ggml_tensor * ggml_pad( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -6924,7 +6986,7 @@ struct ggml_tensor * ggml_timestep_embedding( bool is_node = false; if (timesteps->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -7050,7 +7112,7 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * v, struct ggml_tensor * d, bool masked) { - GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes"); + GGML_ABORT("TODO: adapt to ggml_flash_attn_ext() changes"); GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -7149,7 +7211,7 @@ struct ggml_tensor * ggml_ssm_conv( bool is_node = false; if (s->grad || x->grad || c->grad || sq->grad) { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement is_node = true; } @@ -7203,7 +7265,7 @@ struct ggml_tensor * ggml_ssm_scan( bool is_node = false; if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement is_node = true; } @@ -7235,7 +7297,7 @@ struct ggml_tensor * ggml_win_part( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -7273,7 +7335,7 @@ struct ggml_tensor * ggml_win_unpart( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -7303,7 +7365,7 @@ struct ggml_tensor * ggml_get_rel_pos( bool is_node = false; if (a->grad) { - GGML_ASSERT(false); // TODO: implement backward + GGML_ABORT("fatal error"); // TODO: implement backward is_node = true; } @@ -7369,13 +7431,15 @@ struct ggml_tensor * ggml_add_rel_pos_inplace( return ggml_add_rel_pos_impl(ctx, a, pw, ph, true); } -// gmml_unary +// ggml_unary static struct ggml_tensor * ggml_unary_impl( struct ggml_context * ctx, struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { + GGML_ASSERT(ggml_is_contiguous_1(a)); + bool is_node = false; if (!inplace && (a->grad)) { @@ -7865,10 +7929,6 @@ static void ggml_compute_forward_dup_same_cont( GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == dst->type); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const size_t nb00 = src0->nb[0]; const size_t nb0 = dst->nb[0]; @@ -7897,10 +7957,6 @@ static void ggml_compute_forward_dup_f16( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index @@ -7999,7 +8055,7 @@ static void ggml_compute_forward_dup_f16( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } else { //printf("%s: this is not optimal - fix me\n", __func__); @@ -8041,7 +8097,7 @@ static void ggml_compute_forward_dup_f16( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } return; @@ -8158,7 +8214,7 @@ static void ggml_compute_forward_dup_f16( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } @@ -8170,10 +8226,6 @@ static void ggml_compute_forward_dup_bf16( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index @@ -8289,7 +8341,7 @@ static void ggml_compute_forward_dup_bf16( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } else { //printf("%s: this is not optimal - fix me\n", __func__); @@ -8349,7 +8401,7 @@ static void ggml_compute_forward_dup_bf16( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } return; @@ -8518,7 +8570,7 @@ static void ggml_compute_forward_dup_bf16( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } @@ -8530,10 +8582,6 @@ static void ggml_compute_forward_dup_f32( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index @@ -8608,7 +8656,7 @@ static void ggml_compute_forward_dup_f32( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } else { //printf("%s: this is not optimal - fix me\n", __func__); @@ -8668,7 +8716,7 @@ static void ggml_compute_forward_dup_f32( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } @@ -8839,7 +8887,7 @@ static void ggml_compute_forward_dup_f32( } } } else { - GGML_ASSERT(false); // TODO: implement + GGML_ABORT("fatal error"); // TODO: implement } } @@ -8853,10 +8901,6 @@ static void ggml_compute_forward_dup_bytes( GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); GGML_ASSERT(src0->type == dst->type); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { ggml_compute_forward_dup_same_cont(params, dst); return; @@ -9021,8 +9065,8 @@ static void ggml_compute_forward_dup( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -9037,10 +9081,6 @@ static void ggml_compute_forward_add_f32( GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9116,10 +9156,6 @@ static void ggml_compute_forward_add_f16_f32( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9182,7 +9218,7 @@ static void ggml_compute_forward_add_f16_f32( } else { // src1 is not contiguous - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -9195,10 +9231,6 @@ static void ggml_compute_forward_add_bf16_f32( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9261,7 +9293,7 @@ static void ggml_compute_forward_add_bf16_f32( } else { // src1 is not contiguous - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -9274,10 +9306,6 @@ static void ggml_compute_forward_add_f16_f16( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9317,7 +9345,7 @@ static void ggml_compute_forward_add_f16_f16( } else { // src1 is not contiguous - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -9330,10 +9358,6 @@ static void ggml_compute_forward_add_bf16_bf16( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9373,7 +9397,7 @@ static void ggml_compute_forward_add_bf16_bf16( } else { // src1 is not contiguous - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -9386,10 +9410,6 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -9471,7 +9491,7 @@ static void ggml_compute_forward_add( ggml_compute_forward_add_f32(params, dst); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_TYPE_F16: @@ -9483,7 +9503,7 @@ static void ggml_compute_forward_add( ggml_compute_forward_add_f16_f32(params, dst); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_TYPE_BF16: @@ -9495,7 +9515,7 @@ static void ggml_compute_forward_add( ggml_compute_forward_add_bf16_f32(params, dst); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_TYPE_Q4_0: @@ -9517,13 +9537,16 @@ static void ggml_compute_forward_add( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_add_q_f32(params, dst); } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -9539,10 +9562,6 @@ static void ggml_compute_forward_add1_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -9593,10 +9612,6 @@ static void ggml_compute_forward_add1_f16_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = *(float *) src1->data; @@ -9645,10 +9660,6 @@ static void ggml_compute_forward_add1_f16_f16( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = GGML_FP16_TO_FP32(*(ggml_fp16_t *) src1->data); @@ -9697,10 +9708,6 @@ static void ggml_compute_forward_add1_q_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = *(float *) src1->data; @@ -9766,10 +9773,6 @@ static void ggml_compute_forward_add1_bf16_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = *(float *) src1->data; @@ -9818,10 +9821,6 @@ static void ggml_compute_forward_add1_bf16_bf16( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_is_scalar(src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // scalar to add const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data); @@ -9881,7 +9880,7 @@ static void ggml_compute_forward_add1( ggml_compute_forward_add1_f16_f32(params, dst); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_TYPE_BF16: @@ -9893,7 +9892,7 @@ static void ggml_compute_forward_add1( ggml_compute_forward_add1_bf16_f32(params, dst); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_TYPE_Q4_0: @@ -9916,13 +9915,16 @@ static void ggml_compute_forward_add1( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_add1_q_f32(params, dst); } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -9946,20 +9948,16 @@ static void ggml_compute_forward_acc_f32( size_t offset = ((int32_t *) dst->op_params)[3]; bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) { - if (params->ith != 0) { - return; + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); } - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; + ggml_barrier(params->shared); } const int ith = params->ith; @@ -10045,10 +10043,13 @@ static void ggml_compute_forward_acc( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10061,13 +10062,12 @@ static void ggml_compute_forward_sub_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + const int nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -10129,8 +10129,8 @@ static void ggml_compute_forward_sub( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10145,9 +10145,6 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } const int ith = params->ith; const int nth = params->nth; @@ -10226,8 +10223,8 @@ static void ggml_compute_forward_mul( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10242,10 +10239,6 @@ static void ggml_compute_forward_div_f32( GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -10321,8 +10314,8 @@ static void ggml_compute_forward_div( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10334,13 +10327,12 @@ static void ggml_compute_forward_sqr_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; @@ -10367,8 +10359,8 @@ static void ggml_compute_forward_sqr( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10380,13 +10372,12 @@ static void ggml_compute_forward_sqrt_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; @@ -10413,8 +10404,8 @@ static void ggml_compute_forward_sqrt( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10426,13 +10417,12 @@ static void ggml_compute_forward_log_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; @@ -10459,8 +10449,8 @@ static void ggml_compute_forward_log( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10472,13 +10462,13 @@ static void ggml_compute_forward_sum_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_is_scalar(dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_scalar(dst)); + + assert(ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); @@ -10507,13 +10497,12 @@ static void ggml_compute_forward_sum_f16( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_is_scalar(dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(ggml_fp16_t)); GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) @@ -10541,13 +10530,12 @@ static void ggml_compute_forward_sum_bf16( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_is_scalar(dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_scalar(dst)); + assert(src0->nb[0] == sizeof(ggml_bf16_t)); GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) @@ -10590,8 +10578,8 @@ static void ggml_compute_forward_sum( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10603,9 +10591,7 @@ static void ggml_compute_forward_sum_rows_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -10645,8 +10631,8 @@ static void ggml_compute_forward_sum_rows( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10658,9 +10644,7 @@ static void ggml_compute_forward_mean_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -10704,8 +10688,8 @@ static void ggml_compute_forward_mean( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10717,9 +10701,7 @@ static void ggml_compute_forward_argmax_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -10754,8 +10736,8 @@ static void ggml_compute_forward_argmax( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10767,13 +10749,12 @@ static void ggml_compute_forward_repeat_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + GGML_ASSERT(ggml_can_repeat(src0, dst)); + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat @@ -10812,13 +10793,12 @@ static void ggml_compute_forward_repeat_f16( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_can_repeat(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + GGML_ASSERT(ggml_can_repeat(src0, dst)); + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat @@ -10874,8 +10854,8 @@ static void ggml_compute_forward_repeat( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10887,13 +10867,12 @@ static void ggml_compute_forward_repeat_back_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(params->ith == 0); - GGML_ASSERT(ggml_can_repeat(dst, src0)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + GGML_ASSERT(ggml_can_repeat(dst, src0)); + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat @@ -10953,8 +10932,8 @@ static void ggml_compute_forward_repeat_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -10967,10 +10946,6 @@ static void ggml_compute_forward_concat_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -11026,8 +11001,8 @@ static void ggml_compute_forward_concat( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11039,19 +11014,17 @@ static void ggml_compute_forward_abs_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_abs_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11072,8 +11045,8 @@ static void ggml_compute_forward_abs( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11085,19 +11058,17 @@ static void ggml_compute_forward_sgn_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_sgn_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11118,8 +11089,8 @@ static void ggml_compute_forward_sgn( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11131,19 +11102,17 @@ static void ggml_compute_forward_neg_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_neg_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11164,8 +11133,8 @@ static void ggml_compute_forward_neg( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11177,19 +11146,17 @@ static void ggml_compute_forward_step_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_step_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11210,8 +11177,8 @@ static void ggml_compute_forward_step( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11223,19 +11190,17 @@ static void ggml_compute_forward_tanh_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_tanh_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11256,8 +11221,8 @@ static void ggml_compute_forward_tanh( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11269,19 +11234,17 @@ static void ggml_compute_forward_elu_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_elu_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11302,8 +11265,8 @@ static void ggml_compute_forward_elu( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11315,19 +11278,17 @@ static void ggml_compute_forward_relu_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_relu_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11348,8 +11309,8 @@ static void ggml_compute_forward_relu( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11361,19 +11322,17 @@ static void ggml_compute_forward_sigmoid_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_sigmoid_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11394,8 +11353,8 @@ static void ggml_compute_forward_sigmoid( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11407,13 +11366,9 @@ static void ggml_compute_forward_gelu_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(ggml_is_contiguous_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -11457,8 +11412,8 @@ static void ggml_compute_forward_gelu( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11470,13 +11425,9 @@ static void ggml_compute_forward_gelu_quick_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(ggml_is_contiguous_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -11520,8 +11471,8 @@ static void ggml_compute_forward_gelu_quick( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11533,13 +11484,9 @@ static void ggml_compute_forward_silu_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(ggml_is_contiguous_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); const int ith = params->ith; const int nth = params->nth; @@ -11583,8 +11530,8 @@ static void ggml_compute_forward_silu( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } // ggml_compute_forward_leaky_relu @@ -11595,13 +11542,14 @@ static void ggml_compute_forward_leaky_relu_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; @@ -11631,8 +11579,8 @@ static void ggml_compute_forward_leaky_relu( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11645,15 +11593,11 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * grad = dst->src[1]; - GGML_ASSERT(ggml_is_contiguous_1(grad)); - GGML_ASSERT(ggml_is_contiguous_1(src0)); - GGML_ASSERT(ggml_is_contiguous_1(dst)); - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(ggml_are_same_shape(src0, grad)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + assert(ggml_is_contiguous_1(grad)); + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + assert(ggml_are_same_shape(src0, grad)); const int ith = params->ith; const int nth = params->nth; @@ -11698,8 +11642,8 @@ static void ggml_compute_forward_silu_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11710,19 +11654,17 @@ static void ggml_compute_forward_hardswish_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_hardswish_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11742,8 +11684,8 @@ static void ggml_compute_forward_hardswish( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11753,19 +11695,17 @@ static void ggml_compute_forward_hardsigmoid_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert(dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { ggml_vec_hardsigmoid_f32(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -11786,8 +11726,8 @@ static void ggml_compute_forward_hardsigmoid( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11802,10 +11742,6 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -11862,8 +11798,8 @@ static void ggml_compute_forward_norm( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11877,10 +11813,6 @@ static void ggml_compute_forward_rms_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -11934,8 +11866,8 @@ static void ggml_compute_forward_rms_norm( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -11948,10 +11880,6 @@ static void ggml_compute_forward_rms_norm_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -12111,8 +12039,8 @@ static void ggml_compute_forward_rms_norm_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -12126,10 +12054,6 @@ static void ggml_compute_forward_group_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); const int ith = params->ith; @@ -12209,46 +12133,13 @@ static void ggml_compute_forward_group_norm( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } // ggml_compute_forward_mul_mat -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) -// helper function to determine if it is better to use BLAS or not -// for large matrices, BLAS is faster -static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - //const int64_t ne00 = src0->ne[0]; - //const int64_t ne01 = src0->ne[1]; - - const int64_t ne10 = src1->ne[0]; - - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - - // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float) - // all the experts for each batch element and the processing would become incredibly slow - // TODO: find the optimal values for these - if (dst->op != GGML_OP_MUL_MAT_ID && - ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - //src0->type == GGML_TYPE_F32 && - src1->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) { - - /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/ - return true; - } - - return false; -} -#endif - static void ggml_compute_forward_mul_mat_one_chunk( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -12267,8 +12158,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; // broadcast factors const int64_t r2 = ne12 / ne02; @@ -12342,15 +12233,11 @@ static void ggml_compute_forward_mul_mat_one_chunk( static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, - struct ggml_tensor * dst, - struct ggml_compute_state * state) { + struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -12358,9 +12245,14 @@ static void ggml_compute_forward_mul_mat( const enum ggml_type type = src0->type; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - int64_t const vec_dot_num_rows = type_traits[type].nrows; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat; + int64_t const vec_dot_num_rows = type_traits[type].nrows; + int64_t const matmul_num_cols = type_traits[type].ncols; + int64_t const blck_size_interleave = type_traits[type].blck_size_interleave; + ggml_gemv_t const gemv = type_traits[type].gemv; + ggml_gemm_t const gemm = type_traits[type].gemm; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); @@ -12377,83 +12269,14 @@ static void ggml_compute_forward_mul_mat( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - UNUSED(r2); - UNUSED(r3); - // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(dst)) { - const int64_t ne_plane = ne01*ne00; - const size_t desired_wsize = ne13*ne12*ne_plane*sizeof(float); - UNUSED(desired_wsize); - - if (params->type == GGML_TASK_TYPE_INIT) { - if (type != GGML_TYPE_F32) { - assert(params->wsize >= desired_wsize); - // parallelize by src0 rows - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - // broadcast src0 into src1 across 2nd,3rd dimension - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - float * const wdata = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane; - ggml_to_float_t const to_float = type_traits[type].to_float; - - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - to_float((const char *) x + i01*nb01, wdata + i01*ne00, ne00); - } - } - } - } - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // perform sgemm, parallelization controlled by blas lib - if (ith != 0) { - return; - } - - //const int64_t tgemm0 = ggml_perf_time_us(); - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - const int64_t i03 = i13/r3; - const int64_t i02 = i12/r2; - - const void * x = (char *) src0->data + i02*nb02 + i03*nb03; - const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13); - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - - if (type != GGML_TYPE_F32) { - x = (float *) params->wdata + i13*ne12*ne_plane + i12*ne_plane; - } - - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, - ne1, ne01, ne10, - 1.0f, y, ne10, - x, ne00, - 0.0f, d, ne01); - } - } - //printf("cblas_sgemm = %.3f ms, %lld flops\n", (ggml_perf_time_us() - tgemm0)/1000.0, ne13*ne12*ne1*ne01*ne10*2); - - //printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3); - - return; - } -#endif - #if GGML_USE_LLAMAFILE + // broadcast factors + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; + const bool src1_cont = ggml_is_contiguous(src1); if (src1_cont) { @@ -12467,7 +12290,6 @@ static void ggml_compute_forward_mul_mat( (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type), ith, nth, - params->type, src0->type, src1->type, dst->type)) @@ -12477,36 +12299,43 @@ static void ggml_compute_forward_mul_mat( UseGgmlGemm1:; #endif - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - atomic_store(&state->shared->current_chunk, nth); - if (src1->type != vec_dot_type) { - char * wdata = params->wdata; - const size_t row_size = ggml_row_size(vec_dot_type, ne10); + if (src1->type != vec_dot_type) { + char * wdata = params->wdata; - assert(params->wsize >= ne11*ne12*ne13*row_size); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; + assert(params->wsize >= ne13*nbw3); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + int64_t i11_processed = 0; + if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + 4, ne10, blck_size_interleave); } + i11_processed = ne11 - ne11 % 4; + } + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } - - return; } - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + atomic_store(¶ms->shared->current_chunk, nth); } + ggml_barrier(params->shared); + #if GGML_USE_LLAMAFILE if (src1->type != vec_dot_type) { const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; @@ -12522,7 +12351,6 @@ UseGgmlGemm1:; (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type), ith, nth, - params->type, src0->type, vec_dot_type, dst->type)) @@ -12532,11 +12360,6 @@ UseGgmlGemm1:; UseGgmlGemm2:; #endif -#ifdef GGML_PERF - int chunks_executed = 0; - UNUSED(chunks_executed); -#endif - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) const int64_t nr0 = ne0; @@ -12578,8 +12401,27 @@ UseGgmlGemm2:; const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - //if (ith == 0) - // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); + if ((ggml_n_dims(src0) == 2) && gemv) { + const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; + const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11; + int64_t src0_start = (ith * ne01) / nth; + int64_t src0_end = ((ith + 1) * ne01) / nth; + src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; + src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; + if (src0_start >= src0_end) return; + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (gemm && (ne11 > 3)) { + gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { + gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + return; + } // The first chunk comes from our thread_id, the rest will get auto-assigned. int current_chunk = ith; @@ -12596,23 +12438,12 @@ UseGgmlGemm2:; ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); -#ifdef GGML_PERF - chunks_executed++; -#endif - if (nth >= nchunk0 * nchunk1) { break; } - current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); + current_chunk = atomic_fetch_add(¶ms->shared->current_chunk, 1); } - -#ifdef GGML_PERF - // These numbers are useful when trying to measure how well the threading scheduling works. - //int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1; - //float time = (ggml_perf_time_us() - t0); - //printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed); -#endif } // ggml_compute_forward_mul_mat_id @@ -12634,9 +12465,11 @@ static void ggml_compute_forward_mul_mat_id( const bool src1_cont = ggml_is_contiguous(src1); - ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; - enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; - ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; + enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; + ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float; + int64_t const matmul_num_cols = type_traits[type].ncols; + ggml_gemv_t const gemv = type_traits[type].gemv; // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -12664,32 +12497,33 @@ static void ggml_compute_forward_mul_mat_id( int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11] - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (src1->type != vec_dot_type) { char * wdata = params->wdata; - if (src1->type != vec_dot_type) { - const size_t row_size = ggml_row_size(vec_dot_type, ne10); - assert(params->wsize >= ne11*ne12*ne13*row_size); - assert(src1->type == GGML_TYPE_F32); + const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; - for (int64_t i13 = 0; i13 < ne13; ++i13) { - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; - } + assert(params->wsize >= ne13*nbw3); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } - - // initialize matrix_row_counts - memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + } #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); + // group rows by src0 matrix for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { for (int id = 0; id < n_ids; ++id) { @@ -12701,13 +12535,9 @@ static void ggml_compute_forward_mul_mat_id( matrix_row_counts[i02] += 1; } } - - return; } - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + ggml_barrier(params->shared); // compute each matrix multiplication in sequence for (int cur_a = 0; cur_a < n_as; ++cur_a) { @@ -12725,6 +12555,34 @@ static void ggml_compute_forward_mul_mat_id( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows + if (((ggml_n_dims(src0) - 1) == 2) && gemv) { + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; + src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; + if (src0_cur_start >= src0_cur_end) return; + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + const char * src1_col = (const char *) wdata + + (src1_cont || src1->type != vec_dot_type + ? (i11 + i12 * ne11) * row_size + : (i11 * nb11 + i12 * nb12)); + + gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); + } + continue; + } + // distribute the thread work across the inner or outer loop based on which one is larger const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows @@ -12806,9 +12664,6 @@ static void ggml_compute_forward_out_prod_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - // int64_t t0 = ggml_perf_time_us(); - // UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -12833,73 +12688,10 @@ static void ggml_compute_forward_out_prod_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - bool use_blas = ggml_is_matrix(src0) && - ggml_is_matrix(src1) && - ggml_is_contiguous(src0) && - (ggml_is_contiguous(src1) || ggml_is_transposed(src1)); -#endif - - if (params->type == GGML_TASK_TYPE_INIT) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) // gemm beta will zero dst - if (use_blas) { - return; - } -#endif - if (ith != 0) { - return; - } + if (ith == 0) { ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - return; } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (use_blas) { - if (params->ith != 0) { // All threads other than the first do no work. - return; - } - // Arguments to ggml_compute_forward_out_prod (expressed as major,minor) - // src0: (k,n) - // src1: (k,m) - // dst: (m,n) - // - // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f) - // Also expressed as (major,minor) - // a: (m,k): so src1 transposed - // b: (k,n): so src0 - // c: (m,n) - // - // However, if ggml_is_transposed(src1) is true, then - // src1->data already contains a transposed version, so sgemm mustn't - // transpose it further. - - int n = src0->ne[0]; - int k = src0->ne[1]; - int m = src1->ne[0]; - - int transposeA, lda; - - if (!ggml_is_transposed(src1)) { - transposeA = CblasTrans; - lda = m; - } else { - transposeA = CblasNoTrans; - lda = k; - } - - float * a = (float *) ((char *) src1->data); - float * b = (float *) ((char *) src0->data); - float * c = (float *) ((char *) dst->data); - - cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n); - - return; - } -#endif + ggml_barrier(params->shared); // dst[:,:,:,:] = 0 // for i2,i3: @@ -12975,19 +12767,6 @@ static void ggml_compute_forward_out_prod_f32( } } } - - //int64_t t1 = ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} } static void ggml_compute_forward_out_prod_q_f32( @@ -12997,9 +12776,6 @@ static void ggml_compute_forward_out_prod_q_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - // int64_t t0 = ggml_perf_time_us(); - // UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; @@ -13030,19 +12806,10 @@ static void ggml_compute_forward_out_prod_q_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; } + ggml_barrier(params->shared); // parallelize by last three dimensions @@ -13089,19 +12856,6 @@ static void ggml_compute_forward_out_prod_q_f32( ggml_vec_mad_f32(ne0, d, wdata, *s1); } } - - //int64_t t1 = ggml_perf_time_us(); - //static int64_t acc = 0; - //acc += t1 - t0; - //if (t1 - t0 > 10) { - // printf("\n"); - // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); - // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); - // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); - // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); - - // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); - //} } static void ggml_compute_forward_out_prod( @@ -13130,22 +12884,25 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_out_prod_q_f32(params, dst); } break; case GGML_TYPE_F16: { - GGML_ASSERT(false); // todo + GGML_ABORT("fatal error"); // todo // ggml_compute_forward_out_prod_f16_f32(params, dst); - } break; + } case GGML_TYPE_F32: { ggml_compute_forward_out_prod_f32(params, dst); } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -13161,10 +12918,6 @@ static void ggml_compute_forward_scale_f32( GGML_ASSERT(ggml_is_contiguous(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // scale factor float v; memcpy(&v, dst->op_params, sizeof(float)); @@ -13208,8 +12961,8 @@ static void ggml_compute_forward_scale( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -13233,20 +12986,16 @@ static void ggml_compute_forward_set_f32( size_t offset = ((int32_t *) dst->op_params)[3]; bool inplace = (bool) ((int32_t *) dst->op_params)[4]; - if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) { - if (params->ith != 0) { - return; + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); } - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; + ggml_barrier(params->shared); } const int ith = params->ith; @@ -13323,10 +13072,13 @@ static void ggml_compute_forward_set( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -13395,10 +13147,6 @@ static void ggml_compute_forward_get_rows_q( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13428,6 +13176,8 @@ static void ggml_compute_forward_get_rows_q( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + assert(i01 >= 0 && i01 < ne01); + dequantize_row_q( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); @@ -13441,10 +13191,6 @@ static void ggml_compute_forward_get_rows_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13471,6 +13217,8 @@ static void ggml_compute_forward_get_rows_f16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + assert(i01 >= 0 && i01 < ne01); + ggml_fp16_to_fp32_row( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); @@ -13484,10 +13232,6 @@ static void ggml_compute_forward_get_rows_bf16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13514,7 +13258,9 @@ static void ggml_compute_forward_get_rows_bf16( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - ggml_bf16_to_fp32_row( + assert(i01 >= 0 && i01 < ne01); + + ggml_bf16_to_fp32_row( (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } @@ -13527,10 +13273,6 @@ static void ggml_compute_forward_get_rows_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; @@ -13557,6 +13299,8 @@ static void ggml_compute_forward_get_rows_f32( const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + assert(i01 >= 0 && i01 < ne01); + ggml_vec_cpy_f32(nc, (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03)); @@ -13590,6 +13334,9 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: { ggml_compute_forward_get_rows_q(params, dst); } break; @@ -13608,8 +13355,8 @@ static void ggml_compute_forward_get_rows( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } //static bool first = true; @@ -13640,21 +13387,15 @@ static void ggml_compute_forward_get_rows_back_f32_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(params->ith == 0); + if (params->ith != 0) { + return; + } + GGML_ASSERT(ggml_is_contiguous(dst)); // ggml_compute_forward_dup_same_cont(params, opt0, dst); - if (params->type == GGML_TASK_TYPE_INIT) { - if (params->ith != 0) { - return; - } - memset(dst->data, 0, ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + memset(dst->data, 0, ggml_nbytes(dst)); const int nc = src0->ne[0]; const int nr = ggml_nelements(src1); @@ -13679,21 +13420,15 @@ static void ggml_compute_forward_get_rows_back_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(params->ith == 0); + if (params->ith != 0) { + return; + } + GGML_ASSERT(ggml_is_contiguous(dst)); // ggml_compute_forward_dup_same_cont(params, opt0, dst); - if (params->type == GGML_TASK_TYPE_INIT) { - if (params->ith != 0) { - return; - } - memset(dst->data, 0, ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } + memset(dst->data, 0, ggml_nbytes(dst)); const int nc = src0->ne[0]; const int nr = ggml_nelements(src1); @@ -13728,8 +13463,8 @@ static void ggml_compute_forward_get_rows_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } //static bool first = true; @@ -13759,9 +13494,7 @@ static void ggml_compute_forward_diag_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -13808,8 +13541,8 @@ static void ggml_compute_forward_diag( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -13830,22 +13563,18 @@ static void ggml_compute_forward_diag_mask_f32( GGML_ASSERT(n_past >= 0); - if (!inplace && (params->type == GGML_TASK_TYPE_INIT)) { - if (ith != 0) { - return; + if (!inplace) { + if (ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); } - // memcpy needs to be synchronized across threads to avoid race conditions. - // => do it in INIT phase - GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); - GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); - memcpy( - ((char *) dst->data), - ((char *) src0->data), - ggml_nbytes(dst)); - } - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; + ggml_barrier(params->shared); } // TODO: handle transposed/permuted matrices @@ -13882,8 +13611,8 @@ static void ggml_compute_forward_diag_mask_inf( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -13900,8 +13629,8 @@ static void ggml_compute_forward_diag_mask_zero( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -13917,10 +13646,6 @@ static void ggml_compute_forward_soft_max_f32( assert(ggml_is_contiguous(dst)); assert(ggml_are_same_shape(src0, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - float scale = 1.0f; float max_bias = 0.0f; @@ -14022,11 +13747,12 @@ static void ggml_compute_forward_soft_max( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } + // ggml_compute_forward_soft_max_back static void ggml_compute_forward_soft_max_back_f32( @@ -14042,10 +13768,6 @@ static void ggml_compute_forward_soft_max_back_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src1, dst)); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // TODO: handle transposed/permuted matrices const int ith = params->ith; @@ -14121,8 +13843,8 @@ static void ggml_compute_forward_soft_max_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -14134,9 +13856,7 @@ static void ggml_compute_forward_clamp_f32( const struct ggml_tensor * src0 = dst->src[0]; - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -14204,6 +13924,9 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ2_S: case GGML_TYPE_Q8_K: + case GGML_TYPE_Q4_0_4_4: + case GGML_TYPE_Q4_0_4_8: + case GGML_TYPE_Q4_0_8_8: case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: @@ -14211,8 +13934,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_F64: case GGML_TYPE_COUNT: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -14283,10 +14006,6 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src2 = dst->src[2]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; //const int n_past = ((int32_t *) dst->op_params)[0]; @@ -14413,10 +14132,6 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src2 = dst->src[2]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; //const int n_past = ((int32_t *) dst->op_params)[0]; @@ -14549,8 +14264,8 @@ static void ggml_compute_forward_rope( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -14573,8 +14288,8 @@ static void ggml_compute_forward_rope_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -14591,9 +14306,6 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32( GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -14604,10 +14316,7 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { memset(params->wdata, 0, params->wsize); // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) @@ -14640,13 +14349,8 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, ggml_nbytes(dst)); - - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; } + ggml_barrier(params->shared); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14690,9 +14394,6 @@ static void ggml_compute_forward_conv_transpose_1d_f32( GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -14703,10 +14404,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32( GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { memset(params->wdata, 0, params->wsize); // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) @@ -14739,13 +14437,8 @@ static void ggml_compute_forward_conv_transpose_1d_f32( // need to zero dst since we are accumulating into it memset(dst->data, 0, ggml_nbytes(dst)); - - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; } + ggml_barrier(params->shared); const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; @@ -14795,8 +14488,8 @@ static void ggml_compute_forward_conv_transpose_1d( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -14814,9 +14507,6 @@ static void ggml_compute_forward_im2col_f32( GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; @@ -14847,14 +14537,6 @@ static void ggml_compute_forward_im2col_f32( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { float * const wdata = (float *) dst->data; @@ -14902,9 +14584,6 @@ static void ggml_compute_forward_im2col_f16( GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16); - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS; const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; @@ -14935,14 +14614,6 @@ static void ggml_compute_forward_im2col_f16( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data; @@ -14989,8 +14660,8 @@ static void ggml_compute_forward_im2col( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15008,9 +14679,6 @@ static void ggml_compute_forward_conv_transpose_2d( GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -15021,10 +14689,7 @@ static void ggml_compute_forward_conv_transpose_2d( GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb10 == sizeof(float)); - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } + if (ith == 0) { memset(params->wdata, 0, params->wsize); // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) @@ -15059,13 +14724,8 @@ static void ggml_compute_forward_conv_transpose_2d( } memset(dst->data, 0, ggml_nbytes(dst)); - - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; } + ggml_barrier(params->shared); const int32_t stride = ggml_get_op_params_i32(dst, 0); @@ -15112,10 +14772,9 @@ static void ggml_compute_forward_pool_1d_sk_p0( const struct ggml_tensor * src = dst->src[0]; - assert(src->type == GGML_TYPE_F32); - assert(params->ith == 0); + assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -15126,28 +14785,27 @@ static void ggml_compute_forward_pool_1d_sk_p0( const int64_t rs = dst->ne[0]; while (cdata < data_end) { - const float * const srow = (const float *)cdata; - + const void * srow = (const void *)cdata; int j = 0; - for (int64_t i = 0; i < rs; ++i) { switch (op) { case GGML_OP_POOL_AVG: drow[i] = 0; break; case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } for (int ki = 0; ki < k; ++ki) { + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow[j]; break; - case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + case GGML_OP_POOL_AVG: drow[i] += srow_j; break; + case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } ++j; } switch (op) { case GGML_OP_POOL_AVG: drow[i] /= k; break; case GGML_OP_POOL_MAX: break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } @@ -15181,10 +14839,9 @@ static void ggml_compute_forward_pool_2d( const struct ggml_tensor * src = dst->src[0]; - GGML_ASSERT(src->type == GGML_TYPE_F32); - GGML_ASSERT(params->ith == 0); + assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -15217,7 +14874,7 @@ static void ggml_compute_forward_pool_2d( switch (op) { case GGML_OP_POOL_AVG: *out = 0; break; case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } const int ix = offset0 + ox * s0; @@ -15225,21 +14882,22 @@ static void ggml_compute_forward_pool_2d( for (int ky = 0; ky < k1; ++ky) { if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; - const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; if (j < 0 || j >= src->ne[0]) continue; + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: *out += srow[j]; break; - case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + case GGML_OP_POOL_AVG: *out += srow_j; break; + case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } } switch (op) { case GGML_OP_POOL_AVG: *out /= ka; break; case GGML_OP_POOL_MAX: break; - case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } } @@ -15257,10 +14915,6 @@ static void ggml_compute_forward_upscale_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->type == GGML_TYPE_F32); const int ith = params->ith; @@ -15307,8 +14961,8 @@ static void ggml_compute_forward_upscale( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15321,10 +14975,6 @@ static void ggml_compute_forward_pad_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT( dst->nb[0] == sizeof(float)); @@ -15369,8 +15019,8 @@ static void ggml_compute_forward_pad( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15381,10 +15031,6 @@ static void ggml_compute_forward_arange_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_ASSERT(dst->nb[0] == sizeof(float)); const int ith = params->ith; @@ -15414,8 +15060,8 @@ static void ggml_compute_forward_arange( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15423,10 +15069,6 @@ static void ggml_compute_forward_timestep_embedding_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const struct ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15469,8 +15111,8 @@ static void ggml_compute_forward_timestep_embedding( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15482,10 +15124,6 @@ static void ggml_compute_forward_argsort_f32( const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_UNARY_OP_LOCALS GGML_ASSERT(nb0 == sizeof(float)); @@ -15532,8 +15170,8 @@ static void ggml_compute_forward_argsort( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15546,8 +15184,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( const struct ggml_tensor * v, const struct ggml_tensor * mask, struct ggml_tensor * dst) { - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); GGML_TENSOR_LOCALS(int64_t, neq, q, ne) GGML_TENSOR_LOCALS(size_t, nbq, q, nb) @@ -15592,14 +15228,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t rv2 = neq2/nev2; const int64_t rv3 = neq3/nev3; - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // parallelize by q rows using ggml_vec_dot_f32 // total rows in q @@ -15765,8 +15393,8 @@ static void ggml_compute_forward_flash_attn_ext( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -15782,9 +15410,6 @@ static void ggml_compute_forward_flash_attn_back_f32( const struct ggml_tensor * v = dst->src[2]; const struct ggml_tensor * d = dst->src[3]; - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) GGML_TENSOR_LOCALS(size_t, nbq, q, nb) GGML_TENSOR_LOCALS(int64_t, nek, k, ne) @@ -15831,16 +15456,10 @@ static void ggml_compute_forward_flash_attn_back_f32( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith == 0) { - memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); - } - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; + if (ith == 0) { + memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3); } + ggml_barrier(params->shared); const int64_t elem_q = ggml_nelements(q); const int64_t elem_k = ggml_nelements(k); @@ -16110,8 +15729,8 @@ static void ggml_compute_forward_flash_attn_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16120,10 +15739,6 @@ static void ggml_compute_forward_flash_attn_back( static void ggml_compute_forward_ssm_conv_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const struct ggml_tensor * src0 = dst->src[0]; // conv_state const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight @@ -16236,8 +15851,8 @@ static void ggml_compute_forward_ssm_conv( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16246,10 +15861,6 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const struct ggml_tensor * src0 = dst->src[0]; // s const struct ggml_tensor * src1 = dst->src[1]; // x const struct ggml_tensor * src2 = dst->src[2]; // dt @@ -16361,8 +15972,8 @@ static void ggml_compute_forward_ssm_scan( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16371,13 +15982,10 @@ static void ggml_compute_forward_ssm_scan( static void ggml_compute_forward_win_part_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + UNUSED(params); const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) @@ -16427,8 +16035,8 @@ static void ggml_compute_forward_win_part( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16437,13 +16045,10 @@ static void ggml_compute_forward_win_part( static void ggml_compute_forward_win_unpart_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + UNUSED(params); const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) @@ -16491,8 +16096,8 @@ static void ggml_compute_forward_win_unpart( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16559,8 +16164,8 @@ static void ggml_compute_forward_unary( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16569,13 +16174,10 @@ static void ggml_compute_forward_unary( static void ggml_compute_forward_get_rel_pos_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst) { + UNUSED(params); const struct ggml_tensor * src0 = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 GGML_TENSOR_UNARY_OP_LOCALS @@ -16609,8 +16211,8 @@ static void ggml_compute_forward_get_rel_pos( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16625,20 +16227,12 @@ static void ggml_compute_forward_add_rel_pos_f32( const struct ggml_tensor * src2 = dst->src[2]; const bool inplace = (bool) ((int32_t *) dst->op_params)[0]; - if (!inplace && params->type == GGML_TASK_TYPE_INIT) { - if (params->ith != 0) { - return; + if (!inplace) { + if (params->ith == 0) { + memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); } - memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst)); - return; + ggml_barrier(params->shared); } - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359 float * src1_data = (float *) src1->data; @@ -16698,8 +16292,8 @@ static void ggml_compute_forward_add_rel_pos( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16712,18 +16306,17 @@ static void ggml_compute_forward_map_unary_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { fun(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -16745,8 +16338,8 @@ static void ggml_compute_forward_map_unary( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16760,20 +16353,18 @@ static void ggml_compute_forward_map_binary_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(src1)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - assert( dst->nb[0] == sizeof(float)); - assert(src0->nb[0] == sizeof(float)); - assert(src1->nb[0] == sizeof(float)); - for (int i = 0; i < n; i++) { fun(nc, (float *) ((char *) dst->data + i*( dst->nb[1])), @@ -16796,8 +16387,8 @@ static void ggml_compute_forward_map_binary( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -16810,9 +16401,7 @@ static void ggml_compute_forward_map_custom1_f32( const struct ggml_tensor * a = dst->src[0]; - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -16829,9 +16418,7 @@ static void ggml_compute_forward_map_custom2_f32( const struct ggml_tensor * a = dst->src[0]; const struct ggml_tensor * b = dst->src[1]; - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -16849,9 +16436,7 @@ static void ggml_compute_forward_map_custom3_f32( const struct ggml_tensor * b = dst->src[1]; const struct ggml_tensor * c = dst->src[1]; - assert(params->ith == 0); - - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { + if (params->ith != 0) { return; } @@ -16866,10 +16451,6 @@ static void ggml_compute_forward_map_custom1( const struct ggml_tensor * a = dst->src[0]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - struct ggml_map_custom1_op_params p; memcpy(&p, dst->op_params, sizeof(p)); @@ -16885,10 +16466,6 @@ static void ggml_compute_forward_map_custom2( const struct ggml_tensor * a = dst->src[0]; const struct ggml_tensor * b = dst->src[1]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - struct ggml_map_custom2_op_params p; memcpy(&p, dst->op_params, sizeof(p)); @@ -16905,10 +16482,6 @@ static void ggml_compute_forward_map_custom3( const struct ggml_tensor * b = dst->src[1]; const struct ggml_tensor * c = dst->src[2]; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - struct ggml_map_custom3_op_params p; memcpy(&p, dst->op_params, sizeof(p)); @@ -16940,21 +16513,10 @@ static void ggml_compute_forward_cross_entropy_loss_f32( GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc)); - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith == 0) { - memset(sums, 0, sizeof(float) * (nth + nth * nc)); - } - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - if (ith == 0) { - float * dp = (float *) dst->data; - ggml_vec_sum_f32(nth, dp, sums); - dp[0] *= -1.0f / (float) nr; - } - return; + if (ith == 0) { + memset(sums, 0, sizeof(float) * (nth + nth * nc)); } + ggml_barrier(params->shared); const double eps = 1e-9; @@ -17002,7 +16564,13 @@ static void ggml_compute_forward_cross_entropy_loss_f32( } #endif } + ggml_barrier(params->shared); + if (ith == 0) { + float * dp = (float *) dst->data; + ggml_vec_sum_f32(nth, dp, sums); + dp[0] *= -1.0f / (float) nr; + } } static void ggml_compute_forward_cross_entropy_loss( @@ -17018,8 +16586,8 @@ static void ggml_compute_forward_cross_entropy_loss( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } @@ -17042,10 +16610,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ith = params->ith; const int64_t nth = params->nth; - if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - const double eps = 1e-9; // TODO: handle transposed/permuted matrices @@ -17109,14 +16673,14 @@ static void ggml_compute_forward_cross_entropy_loss_back( } break; default: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } ///////////////////////////////// -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) { +static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { GGML_ASSERT(params); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { @@ -17214,7 +16778,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT: { - ggml_compute_forward_mul_mat(params, tensor, state); + ggml_compute_forward_mul_mat(params, tensor); } break; case GGML_OP_MUL_MAT_ID: { @@ -17445,14 +17009,32 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_COUNT: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } } //////////////////////////////////////////////////////////////////////////////// -static size_t ggml_hash_size(size_t min_sz) { +struct ggml_hash_set ggml_hash_set_new(size_t size) { + size = ggml_hash_size(size); + struct ggml_hash_set result; + result.size = size; + result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size); + result.used = GGML_CALLOC(ggml_bitset_size(size), sizeof(ggml_bitset_t)); + return result; +} + +void ggml_hash_set_reset(struct ggml_hash_set * hash_set) { + memset(hash_set->used, 0, sizeof(ggml_bitset_t) * ggml_bitset_size(hash_set->size)); +} + +void ggml_hash_set_free(struct ggml_hash_set * hash_set) { + GGML_FREE(hash_set->used); + GGML_FREE(hash_set->keys); +} + +size_t ggml_hash_size(size_t min_sz) { // next primes after powers of two static const size_t primes[] = { 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, @@ -17463,7 +17045,7 @@ static size_t ggml_hash_size(size_t min_sz) { }; static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); - // find the smallest prime that is larger or equal to min_sz + // find the smallest prime that is larger or equal than min_sz size_t l = 0; size_t r = n_primes; while (l < r) { @@ -17478,67 +17060,6 @@ static size_t ggml_hash_size(size_t min_sz) { return sz; } -static size_t ggml_hash(const void * p) { - return (size_t)p; -} - -size_t ggml_hash_find(const struct ggml_hash_set hash_set, struct ggml_tensor * key) { - size_t h = ggml_hash(key) % hash_set.size; - - // linear probing - size_t i = h; - while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) { - i = (i + 1) % hash_set.size; - if (i == h) { - // visited all hash table entries -> not found - return GGML_HASHTABLE_FULL; - } - } - return i; -} - -bool ggml_hash_contains(struct ggml_hash_set hash_set, struct ggml_tensor * key) { - size_t i = ggml_hash_find(hash_set, key); - return i != GGML_HASHTABLE_FULL && hash_set.keys[i] == key; -} - -size_t ggml_hash_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) { - size_t i = ggml_hash_find(hash_set, key); - - GGML_ASSERT(i != GGML_HASHTABLE_FULL); - - if (hash_set.keys[i] == key) { - return GGML_HASHTABLE_ALREADY_EXISTS; - } - - // insert - GGML_ASSERT(hash_set.keys[i] == NULL); - hash_set.keys[i] = key; - return i; -} - -size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tensor * key) { - size_t i = ggml_hash_find(hash_set, key); - - GGML_ASSERT(i != GGML_HASHTABLE_FULL); - - hash_set.keys[i] = key; - return i; -} - -struct ggml_hash_set ggml_hash_set_new(size_t size) { - size = ggml_hash_size(size); - struct ggml_hash_set result; - result.size = size; - result.keys = GGML_MALLOC(sizeof(struct ggml_tensor *) * size); - memset(result.keys, 0, sizeof(struct ggml_tensor *) * size); - return result; -} - -static void ggml_hash_set_free(struct ggml_hash_set hash_set) { - GGML_FREE(hash_set.keys); -} - struct hash_map { struct ggml_hash_set set; struct ggml_tensor ** vals; @@ -17547,13 +17068,12 @@ struct hash_map { static struct hash_map * ggml_new_hash_map(size_t size) { struct hash_map * result = GGML_MALLOC(sizeof(struct hash_map)); result->set = ggml_hash_set_new(size); - result->vals = GGML_MALLOC(sizeof(struct ggml_tensor *) * result->set.size); - memset(result->vals, 0, sizeof(struct ggml_tensor *) * result->set.size); + result->vals = GGML_CALLOC(result->set.size, sizeof(struct ggml_tensor *)); return result; } static void ggml_hash_map_free(struct hash_map * map) { - ggml_hash_set_free(map->set); + ggml_hash_set_free(&map->set); GGML_FREE(map->vals); GGML_FREE(map); } @@ -17574,7 +17094,7 @@ static struct ggml_tensor * ggml_recompute_graph_node( return node; } - if (!ggml_hash_contains(graph->visited_hash_table, node)) { + if (!ggml_hash_contains(&graph->visited_hash_set, node)) { return node; } @@ -17589,8 +17109,8 @@ static struct ggml_tensor * ggml_recompute_graph_node( return node; } - size_t i = ggml_hash_find(replacements->set, node); - GGML_ASSERT(i != GGML_HASHTABLE_FULL); // assert that not full + size_t i = ggml_hash_find(&replacements->set, node); + GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full if (replacements->set.keys[i] == node) { return replacements->vals[i]; } @@ -17648,8 +17168,8 @@ void ggml_build_backward_gradient_checkpointing( // insert checkpoints in replacements for (int i = 0; i < n_checkpoints; ++i) { - size_t k = ggml_hash_find(replacements->set, checkpoints[i]); - GGML_ASSERT(k != GGML_HASHTABLE_FULL); // assert that not full + size_t k = ggml_hash_find(&replacements->set, checkpoints[i]); + GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite replacements->set.keys[k] = checkpoints[i]; replacements->vals[k] = checkpoints[i]; @@ -17677,7 +17197,7 @@ void ggml_build_backward_gradient_checkpointing( // functions to change gradients considering the case that input a might be initial gradient with zero value -static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { +static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { if (ggml_hash_contains(zero_table, a)) { return b; } else { @@ -17685,7 +17205,7 @@ static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct gg } } -static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set zero_table) { +static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) { if (ggml_hash_contains(zero_table, a)) { struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); @@ -17694,7 +17214,7 @@ static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct gg } } -static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { +static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { if (ggml_hash_contains(zero_table, a)) { return ggml_repeat(ctx, b, a); } else { @@ -17702,7 +17222,7 @@ static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct g } } -static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set zero_table) { +static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { if (ggml_hash_contains(zero_table, a)) { return ggml_neg(ctx, b); } else { @@ -17710,7 +17230,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg } } -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; struct ggml_tensor * src2 = tensor->src[2]; @@ -17879,8 +17399,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_MEAN: case GGML_OP_ARGMAX: { - GGML_ASSERT(false); // TODO: implement - } break; + GGML_ABORT("fatal error"); // TODO: implement + } case GGML_OP_REPEAT: { // necessary for llama @@ -17903,16 +17423,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_CONCAT: { - GGML_ASSERT(false); // TODO: implement - } break; + GGML_ABORT("fatal error"); // TODO: implement + } case GGML_OP_SILU_BACK: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_NORM: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_RMS_NORM: { // necessary for llama @@ -17928,12 +17448,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_RMS_NORM_BACK: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_GROUP_NORM: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_MUL_MAT: { // https://cs231n.github.io/optimization-2/#staged @@ -17994,12 +17514,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_MUL_MAT_ID: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_OUT_PROD: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_SCALE: { // necessary for llama @@ -18175,12 +17695,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_GET_ROWS_BACK: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_DIAG: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_DIAG_MASK_INF: { // necessary for llama @@ -18218,8 +17738,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_SOFT_MAX_BACK: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_ROPE: { // necessary for llama @@ -18294,52 +17814,52 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_CLAMP: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_CONV_TRANSPOSE_1D: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_IM2COL: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_CONV_TRANSPOSE_2D: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_POOL_1D: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_POOL_2D: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_UPSCALE: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_PAD: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_ARANGE: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_TIMESTEP_EMBEDDING: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_ARGSORT: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_LEAKY_RELU: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; @@ -18395,13 +17915,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_FLASH_ATTN_BACK: { - GGML_ASSERT(false); // not supported - } break; + GGML_ABORT("fatal error"); // not supported + } case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: case GGML_OP_UNARY: @@ -18439,12 +17959,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_UNARY_OP_TANH: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_UNARY_OP_ELU: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_UNARY_OP_RELU: { if (src0->grad) { @@ -18458,16 +17978,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_UNARY_OP_SIGMOID: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_UNARY_OP_GELU: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_UNARY_OP_GELU_QUICK: { - GGML_ASSERT(false); // TODO: not implemented - } break; + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_UNARY_OP_SILU: { // necessary for llama @@ -18479,7 +17999,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } } break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_OP_GET_REL_POS: @@ -18493,8 +18013,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_MAP_CUSTOM2: case GGML_OP_MAP_CUSTOM3: { - GGML_ASSERT(false); // not supported - } break; + GGML_ABORT("fatal error"); // not supported + } case GGML_OP_CROSS_ENTROPY_LOSS: { if (src0->grad) { @@ -18509,16 +18029,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { - GGML_ASSERT(false); // not supported - } break; + GGML_ABORT("fatal error"); // not supported + } case GGML_OP_NONE: { // nop } break; case GGML_OP_COUNT: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } for (int i = 0; i < GGML_MAX_SRC; ++i) { @@ -18538,7 +18058,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * } // check if already visited - if (ggml_hash_insert(cgraph->visited_hash_table, node) == GGML_HASHTABLE_ALREADY_EXISTS) { + if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { return; } @@ -18584,7 +18104,6 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten } const int n0 = cgraph->n_nodes; - UNUSED(n0); ggml_visit_parents(cgraph, tensor); @@ -18620,7 +18139,7 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); for (int i = 0; i < gf->n_nodes; i++) { if (gf->grads[i]) { - ggml_hash_insert(zero_table, gf->grads[i]); + ggml_hash_insert(&zero_table, gf->grads[i]); } } @@ -18630,7 +18149,7 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * // inplace operations to add gradients are not created by ggml_compute_backward // use allocator to automatically make inplace operations if (node->grad) { - ggml_compute_backward(ctx, node, zero_table); + ggml_compute_backward(ctx, node, &zero_table); } } @@ -18643,16 +18162,29 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * } } - ggml_hash_set_free(zero_table); + ggml_hash_set_free(&zero_table); +} + +static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { + void * ptr = *p; + ptr = (void *) GGML_PAD((uintptr_t) ptr, align); + *p = (void *) ((char *) ptr + size); + return ptr; } static size_t ggml_graph_nbytes(size_t size, bool grads) { - size_t nbytes = sizeof(struct ggml_cgraph); - nbytes += size * sizeof(struct ggml_tensor *) * 2; // leafs + nodes + size_t hash_size = ggml_hash_size(size * 2); + void * p = 0; + incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1); + incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes + incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { - nbytes += size * sizeof(struct ggml_tensor *); // grads + incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads } - nbytes += ggml_hash_size(size * 2) * sizeof(struct ggml_tensor *); // hash set + incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); + + size_t nbytes = (size_t) p; return nbytes; } @@ -18669,19 +18201,19 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size); struct ggml_cgraph * cgraph = (struct ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); - struct ggml_tensor ** data_start = (struct ggml_tensor **) (cgraph + 1); - + // the size of the hash table is doubled since it needs to hold both nodes and leafs size_t hash_size = ggml_hash_size(size * 2); - struct ggml_tensor ** nodes_ptr = data_start; - struct ggml_tensor ** leafs_ptr = nodes_ptr + size; - struct ggml_tensor ** hash_keys_ptr = leafs_ptr + size; - struct ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL; + + void * p = cgraph + 1; + + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); // check that we allocated the correct amount of memory - assert(obj_size == (size_t) ( - (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph)); - - memset(hash_keys_ptr, 0, hash_size * sizeof(struct ggml_tensor *)); + assert(obj_size == (size_t)((char *)p - (char *)cgraph)); *cgraph = (struct ggml_cgraph) { /*.size =*/ size, @@ -18690,13 +18222,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.nodes =*/ nodes_ptr, /*.grads =*/ grads_ptr, /*.leafs =*/ leafs_ptr, - /*.hash_table =*/ { hash_size, hash_keys_ptr }, + /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, }; + ggml_hash_set_reset(&cgraph->visited_hash_set); + return cgraph; } @@ -18712,11 +18243,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.nodes =*/ cgraph0->nodes + i0, /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, /*.leafs =*/ NULL, - /*.hash_table =*/ { 0, NULL }, + /*.hash_table =*/ { 0, NULL, NULL }, /*.order =*/ cgraph0->order, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, }; return cgraph; @@ -18725,7 +18253,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { GGML_ASSERT(dst->size >= src->n_leafs); GGML_ASSERT(dst->size >= src->n_nodes); - GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size); + GGML_ASSERT(dst->visited_hash_set.size >= src->visited_hash_set.size); dst->n_leafs = src->n_leafs; dst->n_nodes = src->n_nodes; @@ -18746,9 +18274,9 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { } } - for (size_t i = 0; i < src->visited_hash_table.size; ++i) { - if (src->visited_hash_table.keys[i]) { - ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]); + for (size_t i = 0; i < src->visited_hash_set.size; ++i) { + if (src->visited_hash_set.keys[i]) { + ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } } @@ -18774,7 +18302,7 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { void ggml_graph_clear(struct ggml_cgraph * cgraph) { cgraph->n_leafs = 0; cgraph->n_nodes = 0; - memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct ggml_tensor *)); + ggml_hash_set_reset(&cgraph->visited_hash_set); } // @@ -18910,16 +18438,7 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); } static void clear_numa_thread_affinity(void) {} #endif -static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) { - int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles; - int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us; - - node->perf_runs++; - node->perf_cycles += cycles_cur; - node->perf_time_us += time_us_cur; -} - -static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) { +static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { int n_tasks = 0; if (ggml_is_empty(node)) { @@ -18931,6 +18450,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ switch (node->op) { case GGML_OP_CPY: case GGML_OP_DUP: + case GGML_OP_CONT: case GGML_OP_ADD: case GGML_OP_ADD1: case GGML_OP_ACC: @@ -18961,8 +18481,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads - case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: { n_tasks = 1; } break; @@ -18974,7 +18494,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ n_tasks = n_threads; } break; default: - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } break; case GGML_OP_SILU_BACK: @@ -18985,37 +18505,21 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ case GGML_OP_RMS_NORM_BACK: case GGML_OP_GROUP_NORM: case GGML_OP_CONCAT: - { - n_tasks = n_threads; - } break; case GGML_OP_MUL_MAT: - { - n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = ggml_nrows(node->src[0]); - //const int nr1 = ggml_nrows(node->src[1]); - - //n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - } break; case GGML_OP_MUL_MAT_ID: - { - n_tasks = n_threads; - } break; case GGML_OP_OUT_PROD: { n_tasks = n_threads; } break; case GGML_OP_GET_ROWS: { - // FIXME: the cost of launching additional threads decreases performance with GPU offloading - //n_tasks = MIN(n_threads, ggml_nelements(node->src[1])); - n_tasks = MIN(n_cur_threads, ggml_nelements(node->src[1])); + // FIXME: get_rows can use additional threads, but the cost of launching additional threads + // decreases performance with GPU offloading + //n_tasks = n_threads; + n_tasks = 1; } break; case GGML_OP_SCALE: case GGML_OP_SET: - case GGML_OP_CONT: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: @@ -19042,14 +18546,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ { n_tasks = MIN(n_threads, ggml_nrows(node->src[0])); } break; - case GGML_OP_CONV_TRANSPOSE_1D: - { - n_tasks = n_threads; - } break; case GGML_OP_IM2COL: - { - n_tasks = n_threads; - } break; + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { n_tasks = n_threads; @@ -19060,33 +18558,12 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ n_tasks = 1; } break; case GGML_OP_UPSCALE: - { - n_tasks = n_threads; - } break; case GGML_OP_PAD: - { - n_tasks = n_threads; - } break; case GGML_OP_ARANGE: - { - n_tasks = n_threads; - } break; case GGML_OP_TIMESTEP_EMBEDDING: - { - n_tasks = n_threads; - } break; case GGML_OP_ARGSORT: - { - n_tasks = n_threads; - } break; case GGML_OP_FLASH_ATTN_EXT: - { - n_tasks = n_threads; - } break; case GGML_OP_FLASH_ATTN_BACK: - { - n_tasks = n_threads; - } break; case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: { @@ -19134,9 +18611,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ } } break; case GGML_OP_CROSS_ENTROPY_LOSS: - { - n_tasks = n_threads; - } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { n_tasks = n_threads; @@ -19147,8 +18621,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ } break; case GGML_OP_COUNT: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } default: { fprintf(stderr, "%s: op not implemented: ", __func__); @@ -19157,8 +18631,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ } else { fprintf(stderr, "%d\n", node->op); } - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } } assert(n_tasks > 0); @@ -19166,184 +18640,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ return n_tasks; } -static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_compute_state * state, const bool do_yield) { - // wait for other threads to finish - const int last_node_n = * node_n; - - while (true) { - if (do_yield) { - sched_yield(); - } - - * node_n = atomic_load(&state->shared->node_n); - if (* node_n != last_node_n) break; -#if defined(__SSE3__) - // Tell the processor we're spinning. It's a processor hint for spinlocks. - _mm_pause(); -#endif - } -} - -static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_compute_state * state, const bool do_yield) { - // wait for other threads to finish - const int last_task_phase = * task_phase; - - while (true) { - if (do_yield) { - sched_yield(); - } - - * task_phase = atomic_load(&state->shared->node_task); - if (* task_phase != last_task_phase) break; -#if defined(__SSE3__) - // Tell the processor we're spinning. It's a processor hint for spinlocks. - _mm_pause(); -#endif - } -} - -static thread_ret_t ggml_graph_compute_thread(void * data) { - struct ggml_compute_state * state = (struct ggml_compute_state *) data; - - const struct ggml_cgraph * cgraph = state->shared->cgraph; - const struct ggml_cplan * cplan = state->shared->cplan; - - const int n_threads = state->shared->n_threads; - - set_numa_thread_affinity(state->ith); - - int node_n = -1; - int task_phase = GGML_TASK_TYPE_FINALIZE; - - while (true) { - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->shared->node_n += 1; - state->ec = GGML_STATUS_ABORTED; - return 0; - } - - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - // all other threads are finished and spinning - // do finalize and init here so we don't have synchronize again - struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_TYPE_FINALIZE, - /*.ith =*/ 0, - /*.nth =*/ 0, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - }; - - if (node_n != -1) { - /* FINALIZE */ - struct ggml_tensor * node = cgraph->nodes[node_n]; - if (GGML_OP_HAS_FINALIZE[node->op]) { - params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - ggml_compute_forward(¶ms, node, state); - } - ggml_graph_compute_perf_stats_node(node, state->shared); - } - - // distribute new work or execute it direct if 1T - while (++node_n < cgraph->n_nodes) { - GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); - struct ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - - state->shared->perf_node_start_cycles = ggml_perf_cycles(); - state->shared->perf_node_start_time_us = ggml_perf_time_us(); - - params.nth = n_tasks; - - if (n_tasks == 1) { - /* INIT */ - if (GGML_OP_HAS_INIT[node->op]) { - params.type = GGML_TASK_TYPE_INIT; - ggml_compute_forward(¶ms, node, state); - } - - // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, - // they do something more efficient than spinning (?) - params.type = GGML_TASK_TYPE_COMPUTE; - ggml_compute_forward(¶ms, node, state); - - if (GGML_OP_HAS_FINALIZE[node->op]) { - params.type = GGML_TASK_TYPE_FINALIZE; - ggml_compute_forward(¶ms, node, state); - } - - ggml_graph_compute_perf_stats_node(node, state->shared); - } else { - break; - } - - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - break; - } - } - - task_phase = GGML_TASK_TYPE_INIT; - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_n, node_n); - atomic_store(&state->shared->node_task, task_phase); - } else { - ggml_graph_compute_thread_sync_node(&node_n, state, false); - ggml_graph_compute_thread_sync_task(&task_phase, state, false); - } - - // check if we should stop - if (node_n >= cgraph->n_nodes) break; - - /* INIT & COMPUTE */ - struct ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = ggml_get_n_tasks(node, n_threads, state->shared->n_threads); - - struct ggml_compute_params params = { - /*.type =*/ GGML_TASK_TYPE_INIT, - /*.ith =*/ state->ith, - /*.nth =*/ n_tasks, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - }; - - if (state->ith < n_tasks) { - if (GGML_OP_HAS_INIT[node->op]) { - ggml_compute_forward(¶ms, node, state); - } - } - - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - task_phase = GGML_TASK_TYPE_COMPUTE; - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_task, task_phase); - } - else { - // TODO: this sched_yield can have significant impact on the performance - either positive or negative - // depending on the workload and the operating system. - // since it is not clear what is the best approach, it should potentially become user-configurable - // ref: https://github.com/ggerganov/ggml/issues/291 - // UPD: adding the do_yield flag seems to resolve the issue universally - const bool do_yield = node_n < 0 || cgraph->nodes[node_n]->op == GGML_OP_MUL_MAT; - ggml_graph_compute_thread_sync_task(&task_phase, state, do_yield); - } - - if (state->ith < n_tasks) { - params.type = GGML_TASK_TYPE_COMPUTE; - ggml_compute_forward(¶ms, node, state); - } - - if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { - task_phase = GGML_TASK_TYPE_FINALIZE; - atomic_store(&state->shared->n_active, n_threads); - atomic_store(&state->shared->node_task, task_phase); - } - else { - ggml_graph_compute_thread_sync_task(&task_phase, state, false); - } - } - - return 0; -} - struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) { if (n_threads <= 0) { n_threads = GGML_DEFAULT_N_THREADS; @@ -19360,7 +18656,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; - const int n_tasks = ggml_get_n_tasks(node, n_threads, 1); + const int n_tasks = ggml_get_n_tasks(node, n_threads); max_tasks = MAX(max_tasks, n_tasks); @@ -19394,17 +18690,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa { const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node)) { - if (node->src[0]->type != GGML_TYPE_F32) { - // here we need memory for fully dequantized matrix from src0 - // take into account that src0 can be broadcasted into src1[2,3] - cur = ggml_type_size(GGML_TYPE_F32) - * node->src[0]->ne[0]*node->src[0]->ne[1] - * node->src[1]->ne[2]*node->src[1]->ne[3]; - } - } else -#endif if (node->src[1]->type != vec_dot_type) { cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1])); } @@ -19457,7 +18742,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne00*ne01*ne02; cur += sizeof(float)*ne10*ne11; } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } break; case GGML_OP_CONV_TRANSPOSE_2D: @@ -19503,8 +18788,8 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa } break; case GGML_OP_COUNT: { - GGML_ASSERT(false); - } break; + GGML_ABORT("fatal error"); + } default: break; } @@ -19523,8 +18808,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa return cplan; } -static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) { - enum ggml_status compute_status = GGML_STATUS_SUCCESS; +static thread_ret_t ggml_graph_compute_thread(void * data) { + struct ggml_compute_state * state = (struct ggml_compute_state *) data; + + const struct ggml_cgraph * cgraph = state->shared->cgraph; + const struct ggml_cplan * cplan = state->shared->cplan; + + set_numa_thread_affinity(state->ith); + + struct ggml_compute_params params = { + /*.ith =*/ state->ith, + /*.nth =*/ state->shared->n_threads, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.shared=*/ state->shared, + }; + + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + struct ggml_tensor * node = cgraph->nodes[node_n]; + + ggml_compute_forward(¶ms, node); + + if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { + state->shared->ec = GGML_STATUS_ABORTED; + } + + ggml_barrier(state->shared); + + if (state->shared->ec != GGML_STATUS_SUCCESS) { + break; + } + } + + return 0; +} + +enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { + GGML_ASSERT(cplan); + GGML_ASSERT(cplan->n_threads > 0); + GGML_ASSERT(cplan->work_size == 0 || cplan->work_data != NULL); + + int n_threads = cplan->n_threads; + + struct ggml_compute_state_shared state_shared = { + /*.cgraph =*/ cgraph, + /*.cgraph_plan =*/ cplan, + /*.n_threads =*/ n_threads, + /*.n_barrier =*/ 0, + /*.n_barrier_passed =*/ 0, + /*.abort_callback =*/ NULL, + /*.abort_callback_data =*/ NULL, + /*.current_chunk =*/ 0, + /*.ec =*/ GGML_STATUS_SUCCESS, + }; #ifdef GGML_USE_OPENMP if (n_threads > 1) { @@ -19534,22 +18870,40 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * { // update the number of threads from the actual number of threads that we got from OpenMP n_threads = omp_get_num_threads(); - workers[0].shared->n_threads = n_threads; - workers[0].shared->n_active = n_threads; + state_shared.n_threads = n_threads; } - ggml_graph_compute_thread(&workers[omp_get_thread_num()]); + + struct ggml_compute_state worker = { + .thrd = 0, + .ith = omp_get_thread_num(), + .shared = &state_shared, + }; + ggml_graph_compute_thread(&worker); } } else { - ggml_graph_compute_thread(&workers[0]); + struct ggml_compute_state worker = { + .thrd = 0, + .ith = 0, + .shared = &state_shared, + }; + ggml_graph_compute_thread(&worker); } #else + struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); + + for (int j = 0; j < n_threads; ++j) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + }; + } + // create thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; ++j) { - const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } + for (int j = 1; j < n_threads; ++j) { + const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + GGML_ASSERT(rc == 0); + UNUSED(rc); } // this is a work thread too @@ -19564,80 +18918,11 @@ static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * } } #endif + // don't leave affinity set on the main thread clear_numa_thread_affinity(); - for (int j = 0; j < n_threads; j++) { - if (workers[j].ec != GGML_STATUS_SUCCESS) { - compute_status = workers[j].ec; - break; - } - } - return compute_status; -} - -enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { - { - GGML_ASSERT(cplan); - GGML_ASSERT(cplan->n_threads > 0); - - if (cplan->work_size > 0) { - GGML_ASSERT(cplan->work_data); - } - } - - int n_threads = cplan->n_threads; - -#if defined(GGML_USE_OPENMP) - n_threads = MIN(n_threads, omp_get_max_threads()); -#endif - - struct ggml_compute_state_shared state_shared = { - /*.cgraph =*/ cgraph, - /*.cgraph_plan =*/ cplan, - /*.perf_node_start_cycles =*/ 0, - /*.perf_node_start_time_us =*/ 0, - /*.n_threads =*/ n_threads, - /*.n_active =*/ n_threads, - /*.node_n =*/ -1, - /*.node_task =*/ GGML_TASK_TYPE_FINALIZE, - /*.abort_callback =*/ NULL, - /*.abort_callback_data =*/ NULL, - /*.current_chunk; =*/ 0, - }; - struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); - const int64_t perf_start_cycles = ggml_perf_cycles(); - const int64_t perf_start_time_us = ggml_perf_time_us(); - - for (int j = 0; j < n_threads; ++j) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .ith = j, - .shared = &state_shared, - .ec = GGML_STATUS_SUCCESS, - }; - } - - enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads); - - // performance stats (graph) - { - int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; - int64_t perf_time_us_cur = ggml_perf_time_us() - perf_start_time_us; - - cgraph->perf_runs++; - cgraph->perf_cycles += perf_cycles_cur; - cgraph->perf_time_us += perf_time_us_cur; - - GGML_PRINT_DEBUG("%s: perf (%d) - cpu = %.3f / %.3f ms, wall = %.3f / %.3f ms\n", - __func__, cgraph->perf_runs, - (double) perf_cycles_cur / (double) ggml_cycles_per_ms(), - (double) cgraph->perf_cycles / (double) ggml_cycles_per_ms() / (double) cgraph->perf_runs, - (double) perf_time_us_cur / 1000.0, - (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); - } - - return compute_status; + return state_shared.ec; } enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) { @@ -19757,7 +19042,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { FILE * fout = ggml_fopen(fname, "wb"); if (!fout) { - fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); return; } @@ -19894,7 +19179,7 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * { FILE * fin = ggml_fopen(fname, "rb"); if (!fin) { - fprintf(stderr, "%s: failed to open %s\n", __func__, fname); + fprintf(stderr, "%s: failed to open %s: %s\n", __func__, fname, strerror(errno)); return result; } @@ -20136,24 +19421,16 @@ struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context * } void ggml_graph_print(const struct ggml_cgraph * cgraph) { - int64_t perf_total_per_op_us[GGML_OP_COUNT] = {0}; - GGML_PRINT("=== GRAPH ===\n"); GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes); for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; - perf_total_per_op_us[node->op] += MAX(1, node->perf_time_us); - - GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s (%3d) cpu = %7.3f / %7.3f ms, wall = %7.3f / %7.3f ms\n", + GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], - ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ", node->perf_runs, - (double) node->perf_cycles / (double) ggml_cycles_per_ms(), - (double) node->perf_cycles / (double) ggml_cycles_per_ms() / (double) node->perf_runs, - (double) node->perf_time_us / 1000.0, - (double) node->perf_time_us / 1000.0 / node->perf_runs); + ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); } GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs); @@ -20167,14 +19444,6 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { ggml_get_name(node)); } - for (int i = 0; i < GGML_OP_COUNT; i++) { - if (perf_total_per_op_us[i] == 0) { - continue; - } - - GGML_PRINT("perf_total_per_op_us[%16s] = %7.3f ms\n", ggml_op_name(i), (double) perf_total_per_op_us[i] / 1000.0); - } - GGML_PRINT("========================================\n"); } @@ -20233,7 +19502,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fprintf(fp, "digraph G {\n"); fprintf(fp, " newrank = true;\n"); - fprintf(fp, " rankdir = LR;\n"); + fprintf(fp, " rankdir = TB;\n"); for (int i = 0; i < gb->n_nodes; i++) { struct ggml_tensor * node = gb->nodes[i]; @@ -20295,7 +19564,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph } fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]); - if (ggml_nelements(node) < 5) { + if (ggml_nelements(node) < 5 && node->data != NULL) { fprintf(fp, " | ("); for (int j = 0; j < ggml_nelements(node); j++) { if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) { @@ -20754,9 +20023,9 @@ static enum ggml_opt_result linesearch_backtracking( (*step) *= width; } - GGML_ASSERT(false && "line search failed"); + GGML_ABORT("line search failed"); - return GGML_LINESEARCH_FAIL; + //return GGML_LINESEARCH_FAIL; } static enum ggml_opt_result ggml_opt_lbfgs( @@ -21024,9 +20293,9 @@ static enum ggml_opt_result ggml_opt_lbfgs( step[0] = 1.0; } - GGML_ASSERT(false && "lbfgs failed"); + GGML_ABORT("lbfgs failed"); - return GGML_OPT_RESULT_DID_NOT_CONVERGE; + //return GGML_OPT_RESULT_DID_NOT_CONVERGE; } struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) { @@ -21351,6 +20620,9 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21581,6 +20853,7 @@ struct gguf_context * gguf_init_empty(void) { struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { FILE * file = ggml_fopen(fname, "rb"); if (!file) { + fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno)); return NULL; } @@ -21717,10 +20990,10 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p } } break; case GGUF_TYPE_ARRAY: - default: GGML_ASSERT(false && "invalid type"); break; + default: GGML_ABORT("invalid type"); } } break; - default: GGML_ASSERT(false && "invalid type"); + default: GGML_ABORT("invalid type"); } if (!ok) { @@ -21765,7 +21038,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p gguf_tensor_info_sanitize(info); // make sure there is no duplicated tensor names - for (uint64_t j = 0; j < i; ++j) { + for (uint64_t j = 0; j < i && ok; ++j) { if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) { fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data); ok = false; @@ -21814,8 +21087,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p (int64_t) info->ne[3]; if (ne % ggml_blck_size(info->type) != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", - __func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); + fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", + __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type)); fclose(file); gguf_free(ctx); return NULL; @@ -21846,6 +21119,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p }; *params.ctx = ggml_init(pdata); + if (*params.ctx == NULL) { + fprintf(stderr, "%s: failed to initialize context\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } struct ggml_context * ctx_data = *params.ctx; @@ -22295,12 +21574,12 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); GGML_FREE((void *)data); } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { - GGML_ASSERT(false && "nested arrays not supported"); + GGML_ABORT("nested arrays not supported"); } else { gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); } } break; - default: GGML_ASSERT(false && "invalid type"); break; + default: GGML_ABORT("invalid type"); } } } @@ -22309,7 +21588,7 @@ void gguf_add_tensor( struct gguf_context * ctx, const struct ggml_tensor * tensor) { if (gguf_find_tensor(ctx, tensor->name) != -1) { - GGML_ASSERT(false && "duplicated tensor name"); + GGML_ABORT("duplicated tensor name"); } const int idx = ctx->header.n_tensors; @@ -22342,7 +21621,7 @@ void gguf_add_tensor( void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { const int idx = gguf_find_tensor(ctx, name); if (idx < 0) { - GGML_ASSERT(false && "tensor not found"); + GGML_ABORT("tensor not found"); } ctx->infos[idx].type = type; @@ -22351,7 +21630,7 @@ void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggm void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { const int idx = gguf_find_tensor(ctx, name); if (idx < 0) { - GGML_ASSERT(false && "tensor not found"); + GGML_ABORT("tensor not found"); } ctx->infos[idx].data = data; @@ -22480,10 +21759,10 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * } } break; case GGUF_TYPE_ARRAY: - default: GGML_ASSERT(false && "invalid type"); break; + default: GGML_ABORT("invalid type"); } } break; - default: GGML_ASSERT(false && "invalid type"); + default: GGML_ABORT("invalid type"); } } @@ -22544,7 +21823,7 @@ static void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); if (!file) { - GGML_ASSERT(false && "failed to open file for writing"); + GGML_ABORT("failed to open file for writing"); } struct gguf_buf buf = gguf_buf_init(16*1024); @@ -22653,8 +21932,6 @@ int ggml_cpu_has_neon(void) { int ggml_cpu_has_sve(void) { #if defined(__ARM_FEATURE_SVE) - // TODO: Currently, SVE 256 bit is only supported. - GGML_ASSERT(svcntb() == QK8_0); return 1; #else return 0; @@ -22702,7 +21979,7 @@ int ggml_cpu_has_wasm_simd(void) { } int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) +#if defined(GGML_USE_BLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) return 1; #else return 0; @@ -22749,6 +22026,22 @@ int ggml_cpu_has_rpc(void) { #endif } +int ggml_cpu_has_cann(void) { +#if defined(GGML_USE_CANN) + return 1; +#else + return 0; +#endif +} + +int ggml_cpu_has_llamafile(void) { +#if defined(GGML_USE_LLAMAFILE) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_gpublas(void) { return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl(); } diff --git a/llama/ggml.h b/llama/ggml.h index 03ffc2cf..f5821853 100644 --- a/llama/ggml.h +++ b/llama/ggml.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -280,18 +280,8 @@ #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) -#define GGML_ASSERT(x) \ - do { \ - if (!(x)) { \ - fflush(stdout); \ - fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - ggml_print_backtrace(); \ - abort(); \ - } \ - } while (0) - #ifndef NDEBUG -#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached") +#define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0) #elif defined(__GNUC__) #define GGML_UNREACHABLE() __builtin_unreachable() #elif defined(_MSC_VER) @@ -300,6 +290,17 @@ #define GGML_UNREACHABLE() ((void) 0) #endif +#ifdef __cplusplus +#define GGML_NORETURN [[noreturn]] +#elif defined(_MSC_VER) +#define GGML_NORETURN __declspec(noreturn) +#else +#define GGML_NORETURN _Noreturn +#endif + +#define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__) +#define GGML_ASSERT(x) if (!(x)) GGML_ABORT("GGML_ASSERT(%s) failed", #x) + // used to copy the number of elements and stride in bytes of tensors into local variables. // main purpose is to reduce code duplication and improve readability. // @@ -338,10 +339,19 @@ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ GGML_TENSOR_LOCALS(size_t, nb, dst, nb) +#define GGML_TENSOR_BINARY_OP_LOCALS01 \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + #ifdef __cplusplus extern "C" { #endif + GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4) + GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...); + enum ggml_status { GGML_STATUS_ALLOC_FAILED = -2, GGML_STATUS_FAILED = -1, @@ -403,6 +413,9 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, + GGML_TYPE_Q4_0_4_4 = 31, + GGML_TYPE_Q4_0_4_8 = 32, + GGML_TYPE_Q4_0_8_8 = 33, GGML_TYPE_COUNT, }; @@ -444,6 +457,9 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors }; // available tensor operations: @@ -611,11 +627,7 @@ extern "C" { struct ggml_tensor * grad; struct ggml_tensor * src[GGML_MAX_SRC]; - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; - + // source tensor and offset for views struct ggml_tensor * view_src; size_t view_offs; @@ -625,7 +637,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[8]; + // char padding[4]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -654,8 +666,11 @@ extern "C" { GGML_CGRAPH_EVAL_ORDER_COUNT }; + typedef uint32_t ggml_bitset_t; + struct ggml_hash_set { size_t size; + ggml_bitset_t * used; struct ggml_tensor ** keys; }; @@ -669,14 +684,9 @@ extern "C" { struct ggml_tensor ** grads; struct ggml_tensor ** leafs; - struct ggml_hash_set visited_hash_table; + struct ggml_hash_set visited_hash_set; enum ggml_cgraph_eval_order order; - - // performance - int perf_runs; - int64_t perf_cycles; - int64_t perf_time_us; }; // scratch buffer @@ -693,28 +703,6 @@ extern "C" { bool no_alloc; // don't allocate memory for the tensor data }; - - // compute types - - // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. - // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. - enum ggml_task_type { - GGML_TASK_TYPE_INIT = 0, - GGML_TASK_TYPE_COMPUTE, - GGML_TASK_TYPE_FINALIZE, - }; - - struct ggml_compute_params { - enum ggml_task_type type; - - // ith = thread index, nth = number of threads - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; - }; - // numa strategies enum ggml_numa_strategy { GGML_NUMA_STRATEGY_DISABLED = 0, @@ -743,8 +731,6 @@ extern "C" { GGML_API int64_t ggml_cycles(void); GGML_API int64_t ggml_cycles_per_ms(void); - GGML_API void ggml_print_backtrace(void); - // accepts a UTF-8 path, even on Windows GGML_API FILE * ggml_fopen(const char * fname, const char * mode); @@ -759,9 +745,9 @@ extern "C" { GGML_API GGML_CALL size_t ggml_nbytes (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN - GGML_API GGML_CALL int ggml_blck_size(enum ggml_type type); - GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block - GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type); + GGML_API GGML_CALL size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block + GGML_API GGML_CALL size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row GGML_DEPRECATED( GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float @@ -798,6 +784,8 @@ extern "C" { GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + // use this to compute the memory overhead of a tensor GGML_API size_t ggml_tensor_overhead(void); @@ -2048,8 +2036,8 @@ extern "C" { // ggml_graph_plan() has to be called before ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data - GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); - GGML_API enum ggml_status ggml_graph_compute ( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); + GGML_API enum ggml_status ggml_graph_compute( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); // same as ggml_graph_compute() but the work data is allocated as a part of the context // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); @@ -2442,6 +2430,8 @@ extern "C" { GGML_API int ggml_cpu_has_rpc (void); GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_matmul_int8(void); + GGML_API int ggml_cpu_has_cann (void); + GGML_API int ggml_cpu_has_llamafile (void); // // Internal types and functions exposed for tests and benchmarks @@ -2455,20 +2445,31 @@ extern "C" { #endif typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); - typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, - const void * GGML_RESTRICT y, size_t by, int nrc); + typedef void (*ggml_from_float_to_mat_t) + (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); + typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, + const void * GGML_RESTRICT y, size_t by, int nrc); + typedef void (*ggml_gemv_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); + typedef void (*ggml_gemm_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, + const void * GGML_RESTRICT y, int nr, int nc); typedef struct { - const char * type_name; - int blck_size; - size_t type_size; - bool is_quantized; - ggml_to_float_t to_float; - ggml_from_float_t from_float; - ggml_from_float_t from_float_reference; - ggml_vec_dot_t vec_dot; - enum ggml_type vec_dot_type; - int64_t nrows; // number of rows to process simultaneously; + const char * type_name; + int64_t blck_size; + int64_t blck_size_interleave; // interleave elements in blocks + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float; + ggml_from_float_t from_float_ref; + ggml_from_float_to_mat_t from_float_to_mat; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + int64_t nrows; // number of rows to process simultaneously + int64_t ncols; // number of columns to process simultaneously + ggml_gemv_t gemv; + ggml_gemm_t gemm; } ggml_type_traits_t; GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); diff --git a/llama/grammar-parser.cpp b/llama/grammar-parser.cpp index f5d4404e..ebfb3198 100644 --- a/llama/grammar-parser.cpp +++ b/llama/grammar-parser.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/grammar-parser.h b/llama/grammar-parser.h index 4c7f1e8a..9a24cad8 100644 --- a/llama/grammar-parser.h +++ b/llama/grammar-parser.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/json-schema-to-grammar.cpp b/llama/json-schema-to-grammar.cpp index 737f05b7..e78c57ab 100644 --- a/llama/json-schema-to-grammar.cpp +++ b/llama/json-schema-to-grammar.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -66,7 +66,234 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -const std::string SPACE_RULE = "\" \"?"; +/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards */ +class string_view { + const std::string & _str; + const size_t _start; + const size_t _end; +public: + string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} + + size_t size() const { + return _end - _start; + } + + size_t length() const { + return size(); + } + + operator std::string() const { + return str(); + } + + std::string str() const { + return _str.substr(_start, _end - _start); + } + + string_view substr(size_t pos, size_t len = std::string::npos) const { + return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); + } + + char operator[](size_t pos) const { + auto index = _start + pos; + if (index >= _end) { + throw std::out_of_range("string_view index out of range"); + } + return _str[_start + pos]; + } + + bool operator==(const string_view & other) const { + std::string this_str = *this; + std::string other_str = other; + return this_str == other_str; + } +}; + +static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { + auto has_min = min_value != std::numeric_limits::min(); + auto has_max = max_value != std::numeric_limits::max(); + + auto digit_range = [&](char from, char to) { + out << "["; + if (from == to) { + out << from; + } else { + out << from << "-" << to; + } + out << "]"; + }; + auto more_digits = [&](int min_digits, int max_digits) { + out << "[0-9]"; + if (min_digits == max_digits && min_digits == 1) { + return; + } + out << "{"; + out << min_digits; + if (max_digits != min_digits) { + out << ","; + if (max_digits != std::numeric_limits::max()) { + out << max_digits; + } + } + out << "}"; + }; + std::function uniform_range = + [&](const string_view & from, const string_view & to) { + size_t i = 0; + while (i < from.length() && i < to.length() && from[i] == to[i]) { + i++; + } + if (i > 0) { + out << "\"" << from.substr(0, i).str() << "\""; + } + if (i < from.length() && i < to.length()) { + if (i > 0) { + out << " "; + } + auto sub_len = from.length() - i - 1; + if (sub_len > 0) { + auto from_sub = from.substr(i + 1); + auto to_sub = to.substr(i + 1); + auto sub_zeros = repeat("0", sub_len); + auto sub_nines = repeat("9", sub_len); + + auto to_reached = false; + out << "("; + if (from_sub == sub_zeros) { + digit_range(from[i], to[i] - 1); + out << " "; + more_digits(sub_len, sub_len); + } else { + out << "[" << from[i] << "] "; + out << "("; + uniform_range(from_sub, sub_nines); + out << ")"; + if (from[i] < to[i] - 1) { + out << " | "; + if (to_sub == sub_nines) { + digit_range(from[i] + 1, to[i]); + to_reached = true; + } else { + digit_range(from[i] + 1, to[i] - 1); + } + out << " "; + more_digits(sub_len, sub_len); + } + } + if (!to_reached) { + out << " | "; + digit_range(to[i], to[i]); + out << " "; + uniform_range(sub_zeros, to_sub); + } + out << ")"; + } else { + out << "[" << from[i] << "-" << to[i] << "]"; + } + } + }; + + if (has_min && has_max) { + if (min_value < 0 && max_value < 0) { + out << "\"-\" ("; + _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); + out << ")"; + return; + } + + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); + out << ") | "; + min_value = 0; + } + + auto min_s = std::to_string(min_value); + auto max_s = std::to_string(max_value); + auto min_digits = min_s.length(); + auto max_digits = max_s.length(); + + for (auto digits = min_digits; digits < max_digits; digits++) { + uniform_range(min_s, repeat("9", digits)); + min_s = "1" + repeat("0", digits); + out << " | "; + } + uniform_range(min_s, max_s); + return; + } + + auto less_decimals = std::max(decimals_left - 1, 1); + + if (has_min) { + if (min_value < 0) { + out << "\"-\" ("; + _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + out << ") | [0] | [1-9] "; + more_digits(0, decimals_left - 1); + } else if (min_value == 0) { + if (top_level) { + out << "[0] | [1-9] "; + more_digits(0, less_decimals); + } else { + more_digits(1, decimals_left); + } + } else if (min_value <= 9) { + char c = '0' + min_value; + auto range_start = top_level ? '1' : '0'; + if (c > range_start) { + digit_range(range_start, c - 1); + out << " "; + more_digits(1, less_decimals); + out << " | "; + } + digit_range(c, '9'); + out << " "; + more_digits(0, less_decimals); + } else { + auto min_s = std::to_string(min_value); + auto len = min_s.length(); + auto c = min_s[0]; + + if (c > '1') { + digit_range(top_level ? '1' : '0', c - 1); + out << " "; + more_digits(len, less_decimals); + out << " | "; + } + digit_range(c, c); + out << " ("; + _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + out << ")"; + if (c < '9') { + out << " | "; + digit_range(c + 1, '9'); + out << " "; + more_digits(len - 1, less_decimals); + } + } + return; + } + + if (has_max) { + if (max_value >= 0) { + if (top_level) { + out << "\"-\" [1-9] "; + more_digits(0, less_decimals); + out << " | "; + } + _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); + } else { + out << "\"-\" ("; + _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + out << ")"; + } + return; + } + + throw std::runtime_error("At least one of min_value or max_value must be set"); +} + +const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}"; struct BuiltinRule { std::string content; @@ -83,7 +310,7 @@ std::unordered_map PRIMITIVE_RULES = { {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, - {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, + {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, {"null", {"\"null\" space", {}}}, }; @@ -115,7 +342,7 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { }; std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; -std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; +std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; template std::string join(Iterator begin, Iterator end, const std::string & separator) { @@ -186,7 +413,6 @@ static std::string format_literal(const std::string & literal) { return "\"" + escaped + "\""; } - class SchemaConverter { private: std::function _fetch_json; @@ -414,6 +640,75 @@ private: return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space"); } + /* + Returns a rule that matches a JSON string that is none of the provided strings + + not_strings({"a"}) + -> ["] ( [a] char+ | [^"a] char* )? ["] space + not_strings({"and", "also"}) + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + */ + std::string _not_strings(const std::vector & strings) { + + struct TrieNode { + std::map children; + bool is_end_of_string; + + TrieNode() : is_end_of_string(false) {} + + void insert(const std::string & string) { + auto node = this; + for (char c : string) { + node = &node->children[c]; + } + node->is_end_of_string = true; + } + }; + + TrieNode trie; + for (const auto & s : strings) { + trie.insert(s); + } + + std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); + std::ostringstream out; + out << "[\"] ( "; + std::function visit = [&](const TrieNode & node) { + std::ostringstream rejects; + auto first = true; + for (const auto & kv : node.children) { + rejects << kv.first; + if (first) { + first = false; + } else { + out << " | "; + } + out << "[" << kv.first << "]"; + if (!kv.second.children.empty()) { + out << " ("; + visit(kv.second); + out << ")"; + } else if (kv.second.is_end_of_string) { + out << " " << char_rule << "+"; + } + } + if (!node.children.empty()) { + if (!first) { + out << " | "; + } + out << "[^\"" << rejects.str() << "] " << char_rule << "*"; + } + }; + visit(trie); + + out << " )"; + if (!trie.is_end_of_string) { + out << "?"; + } + out << " [\"] space"; + return out.str(); + } + std::string _resolve_ref(const std::string & ref) { std::string ref_name = ref.substr(ref.find_last_of('/') + 1); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { @@ -434,6 +729,7 @@ private: std::vector required_props; std::vector optional_props; std::unordered_map prop_kv_rule_names; + std::vector prop_names; for (const auto & kv : properties) { const auto &prop_name = kv.first; const auto &prop_schema = kv.second; @@ -448,11 +744,18 @@ private: } else { optional_props.push_back(prop_name); } + prop_names.push_back(prop_name); } - if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get())) { + if ((additional_properties.is_boolean() && additional_properties.get()) || additional_properties.is_object()) { std::string sub_name = name + (name.empty() ? "" : "-") + "additional"; - std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value"); - std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule); + std::string value_rule = + additional_properties.is_object() ? visit(additional_properties, sub_name + "-value") + : _add_primitive("value", PRIMITIVE_RULES.at("value")); + + auto key_rule = + prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string")) + : _add_rule(sub_name + "-k", _not_strings(prop_names)); + std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule); prop_kv_rule_names["*"] = kv_rule; optional_props.push_back("*"); } @@ -478,15 +781,11 @@ private: } std::string k = ks[0]; std::string kv_rule_name = prop_kv_rule_names[k]; - if (k == "*") { - res = _add_rule( - name + (name.empty() ? "" : "-") + "additional-kvs", - kv_rule_name + " ( \",\" space " + kv_rule_name + " )*" - ); - } else if (first_is_optional) { - res = "( \",\" space " + kv_rule_name + " )?"; + std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; + if (first_is_optional) { + res = comma_ref + (k == "*" ? "*" : "?"); } else { - res = kv_rule_name; + res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : ""); } if (ks.size() > 1) { res += " " + _add_rule( @@ -620,17 +919,19 @@ public: } else if (schema_type.is_array()) { std::vector schema_types; for (const auto & t : schema_type) { - schema_types.push_back({{"type", t}}); + json schema_copy(schema); + schema_copy["type"] = t; + schema_types.push_back(schema_copy); } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } else if (schema.contains("const")) { - return _add_rule(rule_name, _generate_constant_rule(schema["const"])); + return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); } else if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | ")); + return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -712,6 +1013,24 @@ public: int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); + } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { + int min_value = std::numeric_limits::min(); + int max_value = std::numeric_limits::max(); + if (schema.contains("minimum")) { + min_value = schema["minimum"].get(); + } else if (schema.contains("exclusiveMinimum")) { + min_value = schema["exclusiveMinimum"].get() + 1; + } + if (schema.contains("maximum")) { + max_value = schema["maximum"].get(); + } else if (schema.contains("exclusiveMaximum")) { + max_value = schema["exclusiveMaximum"].get() - 1; + } + std::stringstream out; + out << "("; + _build_min_max_int(min_value, max_value, out); + out << ") space"; + return _add_rule(rule_name, out.str()); } else if (schema.empty() || schema_type == "object") { return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); } else { diff --git a/llama/json-schema-to-grammar.h b/llama/json-schema-to-grammar.h index 053f509d..d3311b70 100644 --- a/llama/json-schema-to-grammar.h +++ b/llama/json-schema-to-grammar.h @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * diff --git a/llama/llama-grammar.cpp b/llama/llama-grammar.cpp new file mode 100644 index 00000000..e5e67c7b --- /dev/null +++ b/llama/llama-grammar.cpp @@ -0,0 +1,565 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "llama-grammar.h" + +#include "llama-vocab.h" +#include "llama-sampling.h" + +#include + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. +std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src.c_str(); + std::vector code_points; + + // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. + code_points.reserve(src.size() + 1); + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +} + +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} + +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; // NOLINT + case LLAMA_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; + + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + found = true; + pos += 1; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool llama_grammar_match_partial_char( + const llama_grammar_element * pos, + const llama_partial_utf8 partial_utf8) { + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + return true; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + llama_grammar_stacks & new_stacks) { + if (stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + new_stacks.emplace_back(stack); + } + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: + if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + // only add the stack if it's not a duplicate of one we already have + new_stacks.emplace_back(stack); + } + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + GGML_ABORT("fatal error"); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const uint32_t chr, + llama_grammar_stacks & new_stacks) { + new_stacks.clear(); + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } +} + +static llama_grammar_candidates llama_grammar_reject_candidates( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const llama_grammar_candidates & candidates) { + GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return {}; + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +llama_grammar_candidates llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates) { + + llama_grammar_candidates rejects; + rejects.reserve(candidates.size()); + + if (stack.empty()) { + for (const auto & tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + llama_grammar_candidates next_candidates; + next_candidates.reserve(candidates.size()); + + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + llama_grammar_stack stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + llama_grammar_stacks next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (const auto & tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +static bool llama_grammar_detect_left_recursion( + const llama_grammar_rules & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const llama_grammar_rule & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + return false; +} + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init_impl( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; + } + } + + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. + return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; +} + +void llama_grammar_free_impl(struct llama_grammar * grammar) { + delete grammar; +} + +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { + llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; + + // redirect elements in stacks to point to new rules + for (size_t is = 0; is < result->stacks.size(); is++) { + for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { + for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { + if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + result->stacks[is][ie] = &result->rules[ir0][ir1]; + } + } + } + } + } + + return result; +} + +void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { + GGML_ASSERT(grammar); + GGML_ASSERT(vocab); + + int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eog = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eog = true; + break; + } + } + + std::vector, llama_partial_utf8>> candidates_decoded; + candidates_decoded.reserve(candidates->size); + + llama_grammar_candidates candidates_grammar; + candidates_grammar.reserve(candidates->size); + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const std::string & piece = vocab->cache_token_to_piece.at(id); + + if (llama_token_is_eog_impl(*vocab, id)) { + if (!allow_eog) { + candidates->data[i].logit = -INFINITY; + } + } else if (piece.empty() || piece[0] == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (const auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (llama_token_is_eog_impl(*vocab, token)) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return; + } + } + GGML_ABORT("fatal error"); + } + + const std::string & piece = vocab->cache_token_to_piece.at(token); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto & code_points = decoded.first; + + llama_grammar_stacks tmp_new_stacks; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); + grammar->stacks = tmp_new_stacks; + } + + grammar->partial_utf8 = decoded.second; + GGML_ASSERT(!grammar->stacks.empty()); + + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; +} diff --git a/llama/llama-grammar.h b/llama/llama-grammar.h new file mode 100644 index 00000000..17f6f88a --- /dev/null +++ b/llama/llama-grammar.h @@ -0,0 +1,65 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "llama-impl.h" + +struct llama_vocab; +struct llama_sampling; + +struct llama_grammar { + const llama_grammar_rules rules; + llama_grammar_stacks stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + llama_partial_utf8 partial_utf8; +}; + +// +// internal API +// + +struct llama_grammar * llama_grammar_init_impl( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +void llama_grammar_free_impl(struct llama_grammar * grammar); + +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); + +void llama_grammar_sample_impl( + const struct llama_grammar * grammar, + const struct llama_vocab * vocab, + const struct llama_sampling * smpl, + llama_token_data_array * candidates); + +void llama_grammar_accept_token_impl( + struct llama_grammar * grammar, + const struct llama_vocab * vocab, + const struct llama_sampling * smpl, + llama_token token); diff --git a/llama/llama-impl.h b/llama/llama-impl.h new file mode 100644 index 00000000..322307c7 --- /dev/null +++ b/llama/llama-impl.h @@ -0,0 +1,52 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#define LLAMA_API_INTERNAL +#include "llama.h" + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +LLAMA_ATTRIBUTE_FORMAT(2, 3) +void llama_log_internal (ggml_log_level level, const char * format, ...); +void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) diff --git a/llama/llama-sampling.cpp b/llama/llama-sampling.cpp new file mode 100644 index 00000000..935547c2 --- /dev/null +++ b/llama/llama-sampling.cpp @@ -0,0 +1,661 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "llama-sampling.h" + +#include +#include +#include +#include +#include +#include + +static void llama_log_softmax(float * array, size_t size) { + float max_l = *std::max_element(array, array + size); + float sum = 0.f; + for (size_t i = 0; i < size; ++i) { + float p = expf(array[i] - max_l); + sum += p; + array[i] = p; + } + + for (size_t i = 0; i < size; ++i) { + array[i] = logf(array[i] / sum); + } +} + +void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + + smpl->rng.seed(seed); +} + +void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { + GGML_ASSERT(candidates->size > 0); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + float max_l = candidates->data[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + float p = expf(candidates->data[i].logit - max_l); + candidates->data[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum; + } + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { + // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast + // if (k >= (int32_t)candidates->size) { + // return; + // } + + const int64_t t_start_sample_us = ggml_time_us(); + + if (k <= 0) { + k = candidates->size; + } + + k = std::max(k, (int) min_keep); + k = std::min(k, (int) candidates->size); + + // Sort scores in descending order + if (!candidates->sorted) { + auto comp = [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }; + if (k <= 128) { + std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + } else { + constexpr int nbuckets = 128; + constexpr float bucket_low = -10.0f; + constexpr float bucket_high = 10.0f; + constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); + constexpr float bucker_inter = -bucket_low * bucket_scale; + + std::vector bucket_idx(candidates->size); + std::vector histo(nbuckets, 0); + + for (int i = 0; i < (int)candidates->size; ++i) { + const float val = candidates->data[i].logit; + int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); + ib = std::max(0, std::min(nbuckets-1, ib)); + bucket_idx[i] = ib; + ++histo[ib]; + } + int nhave = 0; + int ib = nbuckets - 1; + for ( ; ib >= 0; --ib) { + nhave += histo[ib]; + if (nhave >= k) break; + } + std::vector tmp_tokens(nhave); + auto ptr = tmp_tokens.data(); + std::vector bucket_ptrs; + bucket_ptrs.reserve(nbuckets - ib); + for (int j = nbuckets - 1; j >= ib; --j) { + bucket_ptrs.push_back(ptr); + ptr += histo[j]; + } + for (int i = 0; i < (int)candidates->size; ++i) { + int j = bucket_idx[i]; + if (j >= ib) { + *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + } + } + + ptr = tmp_tokens.data(); + int ndone = 0; + for (int j = nbuckets-1; j > ib; --j) { + std::sort(ptr, ptr + histo[j], comp); + ptr += histo[j]; + ndone += histo[j]; + } + std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); + + std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + + } + candidates->sorted = true; + } + candidates->size = k; + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p >= 1.0f) { + return; + } + + llama_sample_softmax_impl(smpl, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + + for (size_t i = 0; i < candidates->size; ++i) { + cum_sum += candidates->data[i].p; + + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= p && i + 1 >= min_keep) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the top-p tokens + candidates->size = last_idx; + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + if (p <= 0.0f || !candidates->size) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + bool min_p_applied = false; + + // if the candidates aren't sorted, try the unsorted implementation first + if (!candidates->sorted) { + std::vector filtered_tokens; + + float max_logit = -FLT_MAX; + for (size_t i = 0; i < candidates->size; ++i) { + max_logit = std::max(max_logit, candidates->data[i].logit); + } + const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max + + for (size_t i = 0; i < candidates->size; ++i) { + if (candidates->data[i].logit >= min_logit) { + filtered_tokens.push_back(candidates->data[i]); + } + } + + // if we have enough values the operation was a success + if (filtered_tokens.size() >= min_keep) { + memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + candidates->size = filtered_tokens.size(); + min_p_applied = true; + } + } + + // if the candidates are sorted or the unsorted implementation failed, use this implementation + if (!min_p_applied) { + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + size_t i = 1; // first token always matches + + for (; i < candidates->size; ++i) { + if (candidates->data[i].logit < min_logit && i >= min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + candidates->size = i; + } + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { + if (z >= 1.0f || candidates->size <= 2) { + return; + } + + llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + const int64_t t_start_sample_us = ggml_time_us(); + + // Compute the first and second derivatives + std::vector first_derivatives(candidates->size - 1); + std::vector second_derivatives(candidates->size - 2); + + for (size_t i = 0; i < first_derivatives.size(); ++i) { + first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; + } + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; + } + + // Calculate absolute value of second derivatives + for (size_t i = 0; i < second_derivatives.size(); ++i) { + second_derivatives[i] = std::abs(second_derivatives[i]); + } + + // Normalize the second derivatives + { + const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f); + + if (second_derivatives_sum > 1e-6f) { + for (float & value : second_derivatives) { + value /= second_derivatives_sum; + } + } else { + for (float & value : second_derivatives) { + value = 1.0f / second_derivatives.size(); + } + } + } + + float cum_sum = 0.0f; + size_t last_idx = candidates->size; + for (size_t i = 0; i < second_derivatives.size(); ++i) { + cum_sum += second_derivatives[i]; + + // Check if the running sum is greater than z or if we have kept at least min_keep tokens + if (cum_sum > z && i >= min_keep) { + last_idx = i; + break; + } + } + + // Resize the output vector to keep only the tokens above the tail location + candidates->size = last_idx; + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + // Reference implementation: + // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr + if (p >= 1.0f) { + return; + } + + // Compute the softmax of logits and calculate entropy + llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + + const int64_t t_start_sample_us = ggml_time_us(); + + float entropy = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + entropy += -candidates->data[i].p * logf(candidates->data[i].p); + } + + // Compute the absolute difference between negative log probability and entropy for each candidate + std::vector shifted_scores; + for (size_t i = 0; i < candidates->size; ++i) { + float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + shifted_scores.push_back(shifted_score); + } + + // Sort tokens based on the shifted_scores and their corresponding indices + std::vector indices(candidates->size); + std::iota(indices.begin(), indices.end(), 0); + + std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { + return shifted_scores[a] < shifted_scores[b]; + }); + + // Compute the cumulative probabilities + float cum_sum = 0.0f; + size_t last_idx = indices.size(); + + for (size_t i = 0; i < indices.size(); ++i) { + size_t idx = indices[i]; + cum_sum += candidates->data[idx].p; + + // Check if the running sum is greater than typical or if we have kept at least min_keep tokens + if (cum_sum > p && i >= min_keep - 1) { + last_idx = i + 1; + break; + } + } + + // Resize the output vector to keep only the locally typical tokens + std::vector new_candidates; + for (size_t i = 0; i < last_idx; ++i) { + size_t idx = indices[i]; + new_candidates.push_back(candidates->data[idx]); + } + + // Replace the data in candidates with the new_candidates data + std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); + candidates->size = new_candidates.size(); + candidates->sorted = false; + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { + const int64_t t_start_sample_us = ggml_time_us(); + + // no need to do anything if there is only one (or zero) candidates + if(candidates->size <= 1) { + return; + } + + // Calculate maximum possible entropy + float max_entropy = -logf(1.0f / candidates->size); + + llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + + // Calculate entropy of the softmax probabilities + float entropy = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { + float prob = candidates->data[i].p; + if (prob > 0.0f) { // Ensure no log(0) + entropy -= prob * logf(prob); + } + } + + // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above) + float normalized_entropy = entropy / max_entropy; + + // Map the normalized entropy to the desired temperature range using the power function + float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val); + +#ifdef DEBUG + LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp); + LLAMA_LOG_INFO("Entropy: %f\n", entropy); + LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy); + LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy); + LLAMA_LOG_INFO("Exponent: %f\n", exponent_val); + LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp); +#endif + + // Apply the dynamically calculated temperature scaling + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].logit /= dyn_temp; + } + + // Re-compute softmax probabilities after scaling logits with dynamic temperature + double max_l_double = candidates->data[0].logit; + double cum_sum_double = 0.0; + for (size_t i = 0; i < candidates->size; ++i) { + double p = exp(candidates->data[i].logit - max_l_double); + candidates->data[i].p = p; // Store the scaled probability + cum_sum_double += p; + } + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities + } + +#ifdef DEBUG + // Print the updated top 25 probabilities after temperature scaling + LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); + for (size_t i = 0; i < 25 && i < candidates->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); + } +#endif + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { + const int64_t t_start_sample_us = ggml_time_us(); + + for (size_t i = 0; i < candidates->size; ++i) { + candidates->data[i].logit /= temp; + } + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_repetition_penalties_impl( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + return; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // Create a frequency map to count occurrences of each token in last_tokens + std::unordered_map token_count; + for (size_t i = 0; i < penalty_last_n; ++i) { + token_count[last_tokens[i]]++; + } + + // Apply frequency and presence penalties to the candidates + for (size_t i = 0; i < candidates->size; ++i) { + const auto token_iter = token_count.find(candidates->data[i].id); + if (token_iter == token_count.end()) { + continue; + } + + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= penalty_repeat; + } else { + candidates->data[i].logit /= penalty_repeat; + } + + candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; + } + + candidates->sorted = false; + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } +} + +void llama_sample_apply_guidance_impl( + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale) { + GGML_ASSERT(smpl); + + const auto t_start_sample_us = ggml_time_us(); + const auto n_vocab = smpl->n_vocab; + + llama_log_softmax(logits, n_vocab); + llama_log_softmax(logits_guidance, n_vocab); + + for (int i = 0; i < n_vocab; ++i) { + auto & l = logits[i]; + const auto & g = logits_guidance[i]; + + l = scale * (l - g) + g; + } + + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; +} + +llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + GGML_ASSERT(smpl); + + const int32_t n_vocab = float(smpl->n_vocab); + + int64_t t_start_sample_us = ggml_time_us(); + + llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + llama_token X = llama_sample_token_impl(smpl, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + return X; +} + +llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { + int64_t t_start_sample_us; + t_start_sample_us = ggml_time_us(); + + llama_sample_softmax_impl(smpl, candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > *mu; + })); + + if (candidates->size == 0) { + candidates->size = 1; + } + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // Normalize the probabilities of the remaining words + llama_sample_softmax_impl(smpl, candidates); + + // Sample the next word X from the remaining words + llama_token X = llama_sample_token_impl(smpl, candidates); + t_start_sample_us = ggml_time_us(); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + *mu = *mu - eta * e; + + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + } + return X; +} + +llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { + const int64_t t_start_sample_us = ggml_time_us(); + + // Find max element + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + if (smpl) { + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->n_sample++; + } + return result; +} + +llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { + GGML_ASSERT(smpl); + + const int64_t t_start_sample_us = ggml_time_us(); + llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + + std::vector probs; + probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + llama_token result = candidates->data[idx].id; + + smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + smpl->n_sample++; + + return result; +} + +llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { + return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +} diff --git a/llama/llama-sampling.h b/llama/llama-sampling.h new file mode 100644 index 00000000..89b8d33a --- /dev/null +++ b/llama/llama-sampling.h @@ -0,0 +1,82 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "llama-impl.h" + +struct llama_sampling { + llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + + std::mt19937 rng; + + int32_t n_vocab = 0; + + mutable int64_t t_sample_us = 0; + mutable int32_t n_sample = 0; + + void reset_timings() const { + t_sample_us = 0; + n_sample = 0; + } +}; + +// +// internal API +// + +void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); + +void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); +void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); + +void llama_sample_repetition_penalties_impl( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); + +void llama_sample_apply_guidance_impl( + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale); + +llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); +llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); +llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); +llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); + diff --git a/llama/llama-vocab.cpp b/llama/llama-vocab.cpp new file mode 100644 index 00000000..a40a9259 --- /dev/null +++ b/llama/llama-vocab.cpp @@ -0,0 +1,1747 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "llama-vocab.h" + +#include "unicode.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// +// helpers +// + +static void replace_all(std::string & s, const std::string & search, const std::string & replace) { + std::string result; + for (size_t pos = 0; ; pos += search.length()) { + auto new_pos = s.find(search, pos); + if (new_pos == std::string::npos) { + result += s.substr(pos, s.size() - pos); + break; + } + result += s.substr(pos, new_pos - pos) + replace; + pos = new_pos; + } + s = std::move(result); +} + +LLAMA_ATTRIBUTE_FORMAT(1, 2) +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +struct naive_trie { + naive_trie() : has_value(false), value(0) { + } + void insert(const char * key, size_t len, int32_t value = 0) { + if (len == 0) { + this->has_value = true; + this->value = value; + return; + } + char c = key[0]; + auto res = children.find(c); + if (res != children.end()) { + res->second.insert(key + 1, len - 1, value); + } else { + auto res = children.insert(std::make_pair(c, naive_trie())); + res.first->second.insert(key + 1, len - 1, value); + } + } + std::pair get_longest_prefix(const char * key, size_t len, size_t offset = 0) { + if (len == 0 || offset == len) { + return std::make_pair(key, offset); + } + char c = key[offset]; + auto res = children.find(c); + if (res != children.end()) { + return res->second.get_longest_prefix(key, len, offset + 1); + } else { + return std::make_pair(key, offset); + } + } + struct naive_trie * traverse(const char c) { + auto res = children.find(c); + if (res != children.end()) { + return &res->second; + } else { + return NULL; + } + } + std::map children; + bool has_value; + llama_token value; +}; + +// +// impl +// + +int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { + GGML_ASSERT(token_left.find(' ') == std::string::npos); + GGML_ASSERT(token_left.find('\n') == std::string::npos); + GGML_ASSERT(token_right.find(' ') == std::string::npos); + GGML_ASSERT(token_right.find('\n') == std::string::npos); + + auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); + if (it == bpe_ranks.end()) { + return -1; + } + + return it->second; +} + +static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { + return vocab.type; +} + +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL; +} + +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN; +} + +static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL; +} + +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE; +} + +static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED; +} + +static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED; +} + +static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); + GGML_ASSERT(llama_is_byte_token(vocab, id)); + const auto & token_data = vocab.id_to_token.at(id); + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); + } + case LLAMA_VOCAB_TYPE_BPE: { + GGML_ABORT("fatal error"); + //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT? + } + case LLAMA_VOCAB_TYPE_WPM: { + GGML_ABORT("fatal error"); + } + default: + GGML_ABORT("fatal error"); + } +} + +static void llama_escape_whitespace(std::string & text) { + replace_all(text, " ", "\xe2\x96\x81"); +} + +static void llama_unescape_whitespace(std::string & word) { + replace_all(word, "\xe2\x96\x81", " "); +} + +struct llm_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + +static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable"); + +// +// SPM tokenizer +// original implementation: +// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// + +struct llm_bigram_spm { + struct comparator { + bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) { + return (l.score < r.score) || (l.score == r.score && l.left > r.left); + } + }; + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + float score; + size_t size; +}; + +struct llm_tokenizer_spm { + llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) { + // split string into utf8 chars + int index = 0; + size_t offs = 0; + while (offs < text.size()) { + llm_symbol sym; + size_t len = unicode_len_utf8(text[offs]); + sym.text = text.c_str() + offs; + sym.n = std::min(len, text.size() - offs); + offs += sym.n; + sym.prev = index - 1; + sym.next = offs == text.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + + // seed the work queue with all possible 2-character tokens. + for (size_t i = 1; i < symbols.size(); ++i) { + try_add_bigram(i - 1, i); + } + + // keep substituting the highest frequency pairs for as long as we can. + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_sym = symbols[bigram.left]; + auto & right_sym = symbols[bigram.right]; + + // if one of the symbols already got merged, skip it. + if (left_sym.n == 0 || right_sym.n == 0 || + left_sym.n + right_sym.n != bigram.size) { + continue; + } + + // merge the right sym into the left one + left_sym.n += right_sym.n; + right_sym.n = 0; + + //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size); + + // remove the right sym from the chain + left_sym.next = right_sym.next; + if (right_sym.next >= 0) { + symbols[right_sym.next].prev = bigram.left; + } + + // find more substitutions + try_add_bigram(left_sym.prev, bigram.left); + try_add_bigram(bigram.left, left_sym.next); + } + + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + resegment(symbol, output); + } + } + +private: + void resegment(llm_symbol & symbol, std::vector & output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab.token_to_id.find(text); + + // Do we need to support is_unused? + if (token != vocab.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + output.reserve(output.size() + symbol.n); + for (int j = 0; j < (int)symbol.n; ++j) { + llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]); + output.push_back(token_id); + } + return; + } + + resegment(symbols[p->second.first], output); + resegment(symbols[p->second.second], output); + } + + void try_add_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); + auto token = vocab.token_to_id.find(text); + + if (token == vocab.token_to_id.end()) { + return; + } + + if (static_cast((*token).second) >= vocab.id_to_token.size()) { + return; + } + + const auto & tok_data = vocab.id_to_token[(*token).second]; + + llm_bigram_spm bigram; + bigram.left = left; + bigram.right = right; + bigram.score = tok_data.score; + bigram.size = text.size(); + + work_queue.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); + } + + const llama_vocab & vocab; + + std::vector symbols; + llm_bigram_spm::queue work_queue; + + std::map> rev_merge; +}; + +// +// BPE tokenizer +// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// tried to simplify unicode stuff, so most likely does not work 100% correctly! +// + +// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused + +struct llm_bigram_bpe { + struct comparator { + bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); + } + }; + + using queue_storage = std::vector; + using queue = std::priority_queue; + llm_symbol::index left; + llm_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct llm_tokenizer_bpe { + llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { + GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); + switch (vocab.type_pre) { + case LLAMA_VOCAB_PRE_TYPE_LLAMA3: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + + // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DBRX: + case LLAMA_VOCAB_PRE_TYPE_SMAUG: + regex_exprs = { + // same as llama3 + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM: + regex_exprs = { + "[\r\n]", + "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+", + "\\s?[!-/:-~!-/:-~‘-‟ -。]+", + "\\s+$", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: + regex_exprs = { + "[\r\n]", + "\\s?\\p{L}+", + "\\s?\\p{P}+", + "[一-龥ࠀ-一가-퟿]+", + "\\p{N}", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_FALCON: + regex_exprs = { + "[\\p{P}\\$\\+<=>\\^~\\|`]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "[0-9][0-9][0-9]", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_STARCODER: + case LLAMA_VOCAB_PRE_TYPE_REFACT: + case LLAMA_VOCAB_PRE_TYPE_COMMAND_R: + case LLAMA_VOCAB_PRE_TYPE_SMOLLM: + case LLAMA_VOCAB_PRE_TYPE_CODESHELL: + regex_exprs = { + "\\p{N}", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_GPT2: + case LLAMA_VOCAB_PRE_TYPE_MPT: + case LLAMA_VOCAB_PRE_TYPE_OLMO: + case LLAMA_VOCAB_PRE_TYPE_JAIS: + regex_exprs = { + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_STABLELM2: + case LLAMA_VOCAB_PRE_TYPE_QWEN2: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_PORO: + regex_exprs = { + " ?[^(\\s|.,!?…。,、।۔،)]+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_CHATGLM4: + regex_exprs = { + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_VIKING: + regex_exprs = { + " ?[^(\\s|.,!?…。,、।۔،)]+", + "\\p{N}", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_TEKKEN: + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + default: + // default regex for BPE tokenization pre-processing + regex_exprs = { + "[\\p{P}\\$\\+<=>\\^~\\|]+", + "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", + "\\p{N}+", + "[0-9][0-9][0-9]", + }; + break; + } + } + + void append(const llama_vocab::id token_id, std::vector & output) const { + output.push_back(token_id); + } + + bool append_bos(std::vector & output) const { + if (vocab.tokenizer_add_bos) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + return true; + } + return false; + } + + bool append_eos(std::vector & output) const { + if (vocab.tokenizer_add_eos) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + return true; + } + return false; + } + + void check_double_bos_eos(const std::vector & output) const { + if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) { + LLAMA_LOG_WARN( + "%s: Added a EOS token to the prompt as specified by the model but the prompt " + "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + } + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + + const auto word_collection = unicode_regex_split(text, regex_exprs); + + symbols_final.clear(); + + for (auto & word : word_collection) { + work_queue = llm_bigram_bpe::queue(); + symbols.clear(); + + int index = 0; + size_t offset = 0; + + if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } + + while (offset < word.size()) { + llm_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset])); + sym.text = word.c_str() + offset; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? -1 : index + 1; + index++; + symbols.emplace_back(sym); + } + for (size_t i = 1; i < symbols.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue.empty()) { + auto bigram = work_queue.top(); + work_queue.pop(); + + auto & left_symbol = symbols[bigram.left]; + auto & right_symbol = symbols[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the finished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols = symbols_final; + + if (!symbols.empty()) { + for (int i = 0; i != -1; i = symbols[i].next) { + auto & symbol = symbols[i]; + if (symbol.n == 0) { + continue; + } + + const std::string str = std::string(symbol.text, symbol.n); + const auto token = vocab.token_to_id.find(str); + + if (token == vocab.token_to_id.end()) { + for (auto j = str.begin(); j != str.end(); ++j) { + std::string byte_str(1, *j); + auto token_multibyte = vocab.token_to_id.find(byte_str); + if (token_multibyte != vocab.token_to_id.end()) { + output.push_back(token_multibyte->second); + } + } + } else { + output.push_back((*token).second); + } + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) { + return; + } + + std::string left_token = std::string(symbols[left].text, symbols[left].n); + std::string right_token = std::string(symbols[right].text, symbols[right].n); + + int rank_found = -1; + + rank_found = vocab.find_bpe_rank(left_token, right_token); + + if (rank_found < 0) { + return; + } + + llm_bigram_bpe bigram; + + bigram.left = left; + bigram.right = right; + bigram.text = left_token + right_token; + bigram.size = left_token.size() + right_token.size(); + bigram.rank = rank_found; + + work_queue.push(bigram); + } + + const llama_vocab & vocab; + + std::vector regex_exprs; + + std::vector symbols; + std::vector symbols_final; + + llm_bigram_bpe::queue work_queue; +}; + +// +// WPM tokenizer +// + +struct llm_tokenizer_wpm { + llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} + + void tokenize(const std::string & text, std::vector & output) const { + const auto & token_map = vocab.token_to_id; + + // normalize and split by whitespace + std::vector words = preprocess(text); + + // bos token prepended already + + // find the longest tokens that form the words + for (const std::string & word : words) { + // skip empty words + if (word.size() == 0) { + continue; + } + + // prepend phantom space + const std::string word1 = "\xe2\x96\x81" + word; + const int n = word1.size(); + + const size_t current_tokens = output.size(); + + // we're at the start of a new word + // move through character position in word + for (int i = 0; i < n; ++i) { + // loop through possible match length + bool match = false; + for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) { + auto it = token_map.find(word1.substr(i, j - i)); + if (it != token_map.end()) { + output.push_back(it->second); + match = true; + i = j - 1; + break; + } + } + + if (!match) { // discard all + output.resize(current_tokens); + break; // and discard next tokens + } + } + + // we didn't find any matches for this word + if (current_tokens == output.size()) { + output.push_back(vocab.special_unk_id); + } + } + } + + // TODO: reduce string copies by using cpts_offs array + std::vector preprocess(const std::string & text) const { + const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + std::vector words(1, ""); + + for (const uint32_t cpt : cpts_nfd) { + const auto flags = unicode_cpt_flags(cpt); + + if (flags.is_whitespace) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + continue; + } + + assert (!flags.is_separator); + if (cpt == 0 || cpt == 0xFFFD || flags.is_control) { + continue; + } + + const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + words.back() = s; // single char word + words.emplace_back(); // start a new word + } else { + words.back() += s; // append char to word + } + } + + if (!words.back().size()) { + words.pop_back(); + } + + return words; + } + + static bool is_chinese_char(uint32_t cpt) { + return + (cpt >= 0x04E00 && cpt <= 0x09FFF) || + (cpt >= 0x03400 && cpt <= 0x04DBF) || + (cpt >= 0x20000 && cpt <= 0x2A6DF) || + (cpt >= 0x2A700 && cpt <= 0x2B73F) || + (cpt >= 0x2B740 && cpt <= 0x2B81F) || + (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 + (cpt >= 0x0F900 && cpt <= 0x0FAFF) || + (cpt >= 0x2F800 && cpt <= 0x2FA1F); + //(cpt >= 0x3000 && cpt <= 0x303F) || + //(cpt >= 0xFF00 && cpt <= 0xFFEF); + } + + const llama_vocab & vocab; +}; + +// +// UGM tokenizer +// + +struct llm_tokenizer_ugm { + llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { + if (vocab.precompiled_charsmap.size() > 0) { + size_t charsmap_offset = 0; + + // First four bytes of precompiled_charsmap contains length of binary + // blob containing XOR-compressed compact double array (XCDA) entries + uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0]; + charsmap_offset += sizeof(xcda_blob_size); + if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) { + throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); + } + + // Next xcda_blob_size bytes contain entries of XOR-compressed compact + // double array (XCDA). Each entry is bit-packed into a 32-bit integer. + xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset]; + xcda_array_size = xcda_blob_size / sizeof(uint32_t); + charsmap_offset += xcda_blob_size; + + // Remaining bytes of precompiled charsmap contain null-terminated + // replacement strings for prefixes matched by the XCDA. + prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset]; + prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset; + } + + for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) { + const auto &token_data = vocab.id_to_token[id]; + + if (llama_is_normal_token(vocab, id)) { + min_score = std::min(min_score, token_data.score); + max_score = std::max(max_score, token_data.score); + } + + if (llama_is_normal_token(vocab, id) || + llama_is_user_defined_token(vocab, id) || + llama_is_unused_token(vocab, id)) { + token_matcher.insert(token_data.text.data(), token_data.text.size(), id); + } + + if (llama_is_user_defined_token(vocab, id)) { + user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size()); + } + } + + unknown_token_score = min_score - unknown_token_score_penalty; + } + + /* This implementation is based on SentencePiece optimized Viterbi algorithm for + * unigram language models. The general idea is to: + * - move along the input sequence in steps of one UTF code point, + * - at each step find all possible tokenizations of the prefix by + * traversing the tokens trie, + * - for each tokenization store the best one so far (by higher score) + * - use the position in sequence after given token as an index to store + * results + * - if there was no valid tokenization of the current UTF code point + * then use unknown token with additional score penalty + * After processing the whole sequence we backtrack from the end to get + * the best tokenization. + */ + void tokenize(const std::string & text, std::vector & output) { + // normalize the input first + std::string normalized; + normalize(text, &normalized); + size_t input_len = normalized.size(); + if (input_len == 0) { + return; + } + + // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores + std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX}); + // at the beginning tokenization score is zero + tokenization_results[0] = { vocab.special_unk_id, 0, 0 }; + + for (size_t input_offset = 0; input_offset < input_len;) { + size_t prefix_offset = input_offset; + // calculate how many code units are in the currently processed UTF code point + size_t n_utf8_code_units = std::min(unicode_len_utf8(normalized[input_offset]), input_len - input_offset); + + // traverse the token matcher trie to find a matching token + bool single_codepoint_token_found = false; + const struct best_tokenization & current_best = tokenization_results[input_offset]; + struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); + + while (prefix_offset <= input_len && node != NULL) { + // check if we found valid token in prefix + if (node->has_value) { + // check if it corresponds to the whole UTF code point + if (prefix_offset - input_offset == n_utf8_code_units) { + single_codepoint_token_found = true; + } + llama_token token_id = node->value; + const auto & token_data = vocab.id_to_token[token_id]; + + // we set the user-defined token scores to 0 to make them more likely to be selected + // (normal token scores are log probabilities, so they are negative) + // score type is double here to make tokenization results exactly + // the same as in the HF tokenizer using SentencePiece + const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score; + const double challenger_score = current_best.score_sum + token_score; + struct best_tokenization & current_champ = tokenization_results[prefix_offset]; + if (challenger_score > current_champ.score_sum) { + struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + current_champ = challenger; + } + } + node = node->traverse(normalized[prefix_offset++]); + } + + // if we didn't find a valid token corresponding to the whole UTF code point + // then use unknown token as the tokenization of this UTF code point + if (!single_codepoint_token_found) { + const double challenger_score = current_best.score_sum + unknown_token_score; + prefix_offset = input_offset + n_utf8_code_units; + struct best_tokenization & current_champ = tokenization_results[prefix_offset]; + if (challenger_score > current_champ.score_sum) { + struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score }; + current_champ = challenger; + } + } + + // move to the next UTF code point + input_offset += n_utf8_code_units; + } + + // now backtrack from the end to gather token ids of the best tokenization + // merge sequences of consecutive unknown tokens into single unknown tokens + bool is_prev_unknown = false; + for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) { + bool is_unknown = tokenization.token_id == vocab.special_unk_id; + if (!(is_prev_unknown && is_unknown)) { + output.push_back(tokenization.token_id); + } + if (tokenization.input_offset == 0) { + break; + } + is_prev_unknown = is_unknown; + } + + // reverse the output since we added tokens starting from the end of the input + std::reverse(output.begin(), output.end()); + } + +private: + const llama_vocab & vocab; + + // helper structure for returning normalization results + struct normalization_result { + const char * normalized; + size_t normalized_len; + size_t consumed_input; + }; + + void normalize(const std::string& input, std::string * normalized) { + normalized->clear(); + normalized->reserve(input.size() * 3); + + const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; + + bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; + bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; + bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces; + + bool is_space_prepended = false; + bool processing_non_ws = false; + + size_t input_len = input.size(); + + for (size_t input_offset = 0; input_offset < input_len; ) { + auto norm_res = normalize_prefix(input, input_offset); + for (size_t i = 0; i < norm_res.normalized_len; i++) { + char c = norm_res.normalized[i]; + if (c != ' ') { + if (!processing_non_ws) { + processing_non_ws = true; + if ((shall_prepend_space && !is_space_prepended) || shall_merge_spaces) { + normalized->append(space); + is_space_prepended = true; + } + } + normalized->push_back(c); + } else { + if (processing_non_ws) { + processing_non_ws = false; + } + if (!shall_merge_spaces) { + normalized->append(space); + } + } + } + + input_offset += norm_res.consumed_input; + } + + if (shall_append_space) { + normalized->append(space); + } + } + + /* + * This structure is a view wrapper for XOR-compressed double array (XCDA) + * See Shunsuke Kanda (2018). Space- and Time-Efficient String Dictionaries. + * Eeach bit-packed entry contains: + * - BASE array value in bits 10-30 + * - LCHECK array value in bits 0-7 + * - LEAF array value in bit 9 + * Entries containing indexes of replacement sequences have set bit 31 + */ + struct xcda_array_view { + public: + xcda_array_view(const uint32_t * xcda_array, size_t xcda_array_size) : xcda_array(xcda_array), xcda_array_size(xcda_array_size) { + } + uint32_t get_base(size_t index) { + uint32_t packed_node = get_node(index); + return (packed_node >> 10) << ((packed_node & (1U << 9)) >> 6); + } + uint32_t get_lcheck(size_t index) { + uint32_t packed_node = get_node(index); + return packed_node & ((1U << 31) | 0xff); + } + bool get_leaf(size_t index) { + uint32_t packed_node = get_node(index); + return (packed_node >> 8) & 1; + } + uint32_t get_value(size_t index) { + uint32_t packed_node = get_node(index); + return packed_node & ((1U << 31) - 1); + } + private: + uint32_t get_node(size_t index) { + if (index > xcda_array_size) { + throw std::runtime_error("Index out of array bounds in XCDA array!"); + } + return xcda_array[index]; + } + const uint32_t * xcda_array; + size_t xcda_array_size; + }; + + struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { + if (input_offset == input.size()) { + return { &input[input_offset], 0, 0 }; + } + + // if input prefix matches some user-defined token return this token as normalization result + auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); + if (user_defined_token_match.second > 0) { + return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; + } + + size_t longest_prefix_length = 0; + size_t longest_prefix_offset = 0; + + if (xcda_array_size > 0) { + struct xcda_array_view xcda_view(xcda_array, xcda_array_size); + + // Find the longest normalized sequence matching the input prefix by walking + // the XOR-compressed compact double array (XCDA) starting from the root node + // We find the index of the next node by calculating BASE[s] ^ c where s is + // the index of the previous node and c is a numerical character value + uint32_t node_index = 0; + // get BASE of the root node + node_index = xcda_view.get_base(node_index); + for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) { + unsigned char c = input[prefix_offset]; + if (c == 0) { + break; + } + node_index ^= c; + // if value of LCHECK is not c it means that this is not a child of + // the previous node, so we stop matching + if (xcda_view.get_lcheck(node_index) != c) { + break; + } + bool is_leaf = xcda_view.get_leaf(node_index); + // get BASE of the current node + node_index ^= xcda_view.get_base(node_index); + // if LEAF of the current node is true, it means that its BASE points to the node + // containing index of replacement sequence for currently matched input prefix + if (is_leaf) + { + longest_prefix_length = prefix_offset - input_offset + 1; + // get index of replacement sequence for currently matched input prefix + longest_prefix_offset = xcda_view.get_value(node_index); + } + } + } + + if (longest_prefix_length > 0) { + // we have a match, so return the replacement sequence + if (longest_prefix_offset >= prefix_replacements_size) { + throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); + } + const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; + return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; + } else { + // check if the input prefix contains a valid sequence of UTF-8 code units + try { + // if yes, return this sequence unmodified + size_t prefix_offset = input_offset; + unicode_cpt_from_utf8(input, prefix_offset); + return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset }; + } catch (std::invalid_argument & /*ex*/) { + // if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER + return { "\xEF\xBF\xBD", 3, 1 }; + } + } + } + + // escaped space symbol - U+2581 (Lower One Eighth Block) + const std::string escaped_space = "\xE2\x96\x81"; + + const char * prefix_replacements = NULL; + size_t prefix_replacements_size = 0; + + const uint32_t * xcda_array = NULL; + size_t xcda_array_size = 0; + + struct naive_trie user_defined_token_matcher; + + // this structure stores the best tokenization so far at input_offset + struct best_tokenization { + llama_token token_id; + size_t input_offset; + float score_sum; + }; + + float min_score = FLT_MAX; + float max_score = -FLT_MAX; + + float unknown_token_score_penalty = 10.0; + float unknown_token_score; + + struct naive_trie token_matcher; +}; + +// +// (de-) tokenize +// + +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE { + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; + +struct fragment_buffer_variant { + fragment_buffer_variant(llama_vocab::id _token) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), + token(_token), + raw_text(_dummy), + offset(0), + length(0) {} + + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), + token((llama_vocab::id) - 1), + raw_text(_raw_text), + offset(_offset), + length(_length){ + GGML_ASSERT(_offset >= 0); + GGML_ASSERT(_length >= 1); + GGML_ASSERT(offset + length <= raw_text.length()); + } + + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_vocab::id token; + const std::string _dummy; + const std::string & raw_text; + const uint64_t offset; + const uint64_t length; +}; + +// #define PRETOKENIZERDEBUG + +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) { + // for each special token + for (const llama_vocab::id special_id : vocab.cache_special_tokens) { + const auto & data = vocab.id_to_token[special_id]; + const auto & special_token = data.text; + + if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) { + // Ignore control and unknown tokens when parse_special == false + continue; + // User-defined tokens are still pre-tokenized before everything else + // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726 + // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.) + } + + // for each text fragment + std::forward_list::iterator it = buffer.begin(); + while (it != buffer.end()) { + auto & fragment = (*it); + + // if a fragment is text ( not yet processed ) + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto & raw_text = fragment.raw_text; + + auto raw_text_base_offset = fragment.offset; + auto raw_text_base_length = fragment.length; + + // loop over the text + while (true) { + // find the first occurrence of a given special token in this fragment + // passing offset argument only limit the "search area" but match coordinates + // are still relative to the source full raw_text + auto match = raw_text.find(special_token, raw_text_base_offset); + + // no occurrences found, stop processing this fragment for a given special token + if (match == std::string::npos) break; + + // check if match is within bounds of offset <-> length + if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break; + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + auto source = std::distance(buffer.begin(), it); + + // if match is further than base offset + // then we have some text to the left of it + if (match > raw_text_base_offset) { + // left + const int64_t left_reminder_offset = raw_text_base_offset + 0; + int64_t left_reminder_length = match - raw_text_base_offset; + + if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) { + while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) { + left_reminder_length--; + } + } + + if (left_reminder_length > 0) { + buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length); + it++; + } + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); +#endif + } + + // special token + buffer.emplace_after(it, special_id); + it++; + + // right + if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { + int64_t right_reminder_offset = match + special_token.length(); + int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); + + if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) { + while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) { + right_reminder_offset++; + right_reminder_length--; + } + } + + if (right_reminder_length > 0) { + buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length); + it++; + } + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); +#endif + + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + + // repeat for the right side + raw_text_base_offset = right_reminder_offset; + raw_text_base_length = right_reminder_length; + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + } else { + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + break; + } + } + } + it++; + } + } +} + +std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { + std::vector output; + std::forward_list fragment_buffer; + + if (!raw_text.empty()) { + fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); + tokenizer_st_partition(vocab, fragment_buffer, parse_special); + } + + switch (vocab.type) { + case LLAMA_VOCAB_TYPE_SPM: + { + // OG tokenizer behavior: + // + // tokenizer.encode('', add_special_tokens=True) returns [1] + // tokenizer.encode('', add_special_tokens=False) returns [] + + bool is_prev_special = true; // prefix with space if first token + + if (add_special && vocab.tokenizer_add_bos) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + is_prev_special = true; + } + + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + + // prefix with space if previous is special + if (vocab.tokenizer_add_space_prefix && is_prev_special) { + raw_text = " " + raw_text; + } + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_spm tokenizer(vocab); + llama_escape_whitespace(raw_text); + tokenizer.tokenize(raw_text, output); + is_prev_special = false; + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + is_prev_special = true; + } + } + + if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + + if (add_special && vocab.tokenizer_add_eos) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } + } break; + case LLAMA_VOCAB_TYPE_BPE: + { + llm_tokenizer_bpe tokenizer(vocab); + + if (add_special) { + tokenizer.append_bos(output); + } + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + tokenizer.append(fragment.token, output); + } + } + + if (add_special) { + tokenizer.append_eos(output); + tokenizer.check_double_bos_eos(output); + } + } break; + case LLAMA_VOCAB_TYPE_WPM: + { + if (add_special) { + GGML_ASSERT(vocab.special_cls_id != -1); + output.push_back(vocab.special_cls_id); + } + + llm_tokenizer_wpm tokenizer(vocab); + + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + + if (add_special) { + GGML_ASSERT(vocab.special_sep_id != -1); + output.push_back(vocab.special_sep_id); + } + } break; + case LLAMA_VOCAB_TYPE_UGM: + { + llm_tokenizer_ugm tokenizer(vocab); + + if (add_special && vocab.tokenizer_add_bos != 0) { + GGML_ASSERT(vocab.special_bos_id != -1); + output.push_back(vocab.special_bos_id); + } + + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + tokenizer.tokenize(raw_text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + + if (add_special && vocab.tokenizer_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) { + LLAMA_LOG_WARN( + "%s: Added a BOS token to the prompt as specified by the model but the prompt " + "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. " + "Are you sure this is what you want?\n", __FUNCTION__); + } + + if (add_special && vocab.tokenizer_add_eos == 1) { + GGML_ASSERT(vocab.special_eos_id != -1); + output.push_back(vocab.special_eos_id); + } + } break; + case LLAMA_VOCAB_TYPE_NONE: + GGML_ABORT("fatal error"); + } + + return output; +} + +llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) { + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); + static const char * hex = "0123456789ABCDEF"; + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; + auto token = vocab.token_to_id.find(buf); + if (token != vocab.token_to_id.end()) { + return (*token).second; + } + // Try to fall back to just the byte as a string + const char buf2[2] = { (char)ch, 0 }; + return vocab.token_to_id.at(buf2); + } + case LLAMA_VOCAB_TYPE_WPM: + case LLAMA_VOCAB_TYPE_BPE: { + return vocab.token_to_id.at(unicode_byte_to_utf8(ch)); + } + default: + GGML_ABORT("fatal error"); + } +} + +const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[token].text.c_str(); +} + +float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[token].score; +} + +llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); + return vocab.id_to_token[token].attr; +} + +bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) { + return token != -1 && ( + token == llama_token_eos_impl(vocab) || + token == llama_token_eot_impl(vocab) + ); +} + +bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) { + return llama_is_control_token(vocab, token); +} + +llama_token llama_token_bos_impl(const struct llama_vocab & vocab) { + return vocab.special_bos_id; +} + +llama_token llama_token_eos_impl(const struct llama_vocab & vocab) { + return vocab.special_eos_id; +} + +llama_token llama_token_cls_impl(const struct llama_vocab & vocab) { + return vocab.special_cls_id; +} + +llama_token llama_token_sep_impl(const struct llama_vocab & vocab) { + return vocab.special_sep_id; +} + +llama_token llama_token_nl_impl(const struct llama_vocab & vocab) { + return vocab.linefeed_id; +} + +llama_token llama_token_pad_impl(const struct llama_vocab & vocab) { + return vocab.special_pad_id; +} + +int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab) { + return vocab.tokenizer_add_bos; +} + +int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab) { + return vocab.tokenizer_add_eos; +} + +llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) { + return vocab.special_prefix_id; +} + +llama_token llama_token_middle_impl(const struct llama_vocab & vocab) { + return vocab.special_middle_id; +} + +llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) { + return vocab.special_suffix_id; +} + +llama_token llama_token_eot_impl(const struct llama_vocab & vocab) { + return vocab.special_eot_id; +} + +int32_t llama_tokenize_impl( + const struct llama_vocab & vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special) { + auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special); + if (n_tokens_max < (int) res.size()) { + // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); + return -((int) res.size()); + } + + for (size_t i = 0; i < res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +static std::string llama_decode_text(const std::string & text) { + std::string decoded_text; + + const auto cpts = unicode_cpts_from_utf8(text); + for (const auto cpt : cpts) { + const auto utf8 = unicode_cpt_to_utf8(cpt); + try { + decoded_text += unicode_utf8_to_byte(utf8); + } catch (const std::out_of_range & /*e*/) { + decoded_text += "[UNK_BYTE_0x"; + for (const auto c : utf8) { + decoded_text += format("%02x", (uint8_t) c); + } + decoded_text += text + "]"; + } + } + + return decoded_text; +} + +// does not write null-terminator to buf +int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) { + // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843 + static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL; + const llama_token_attr attr = llama_token_get_attr_impl(vocab, token); + if (!special && (attr & attr_special)) { + return 0; + } + + // copy piece chars to output text buffer + // skip up to 'lstrip' leading spaces before copying + auto _try_copy = [=] (const char * token, size_t size) -> int32_t { + for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) { + token++; + size--; + } + if (length < (int32_t)size) { + return -(int32_t) size; + } + memcpy(buf, token, size); + return (int32_t) size; + }; + + // if we have a cache - use it + { + const auto & cache = vocab.cache_token_to_piece; + + if (!cache.empty()) { + const auto & result = cache.at(token); + return _try_copy(result.data(), result.size()); + } + } + + if (0 <= token && token < (int32_t) vocab.id_to_token.size()) { + const std::string & token_text = vocab.id_to_token[token].text; + switch (llama_vocab_get_type(vocab)) { + case LLAMA_VOCAB_TYPE_WPM: + case LLAMA_VOCAB_TYPE_SPM: + case LLAMA_VOCAB_TYPE_UGM: { + // NOTE: we accept all unsupported token types, + // suppressing them like CONTROL tokens. + if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { + return _try_copy(token_text.data(), token_text.size()); + } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + std::string result = token_text; + llama_unescape_whitespace(result); + return _try_copy(result.data(), result.size()); + } else if (attr & LLAMA_TOKEN_ATTR_BYTE) { + char byte = (char) llama_token_to_byte(vocab, token); + return _try_copy((char*) &byte, 1); + } + break; + } + case LLAMA_VOCAB_TYPE_BPE: { + // NOTE: we accept all unsupported token types, + // suppressing them like CONTROL tokens. + if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) { + return _try_copy(token_text.data(), token_text.size()); + } else if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + std::string result = llama_decode_text(token_text); + return _try_copy(result.data(), result.size()); + } + break; + } + default: + GGML_ABORT("fatal error"); + } + } + + return 0; +} + +int32_t llama_detokenize_impl( + const struct llama_vocab & vocab, + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special) { + int32_t avail = text_len_max; + int32_t total = 0; + + // remove the leading space + bool remove_space = vocab.tokenizer_add_space_prefix; + + if (remove_special && vocab.tokenizer_add_bos) { + if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) { + remove_space = false; + n_tokens--; + tokens++; + } + } + + if (remove_special && vocab.tokenizer_add_eos) { + if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) { + n_tokens--; + } + } + + for (int32_t i = 0; i < n_tokens; ++i) { + GGML_ASSERT(avail >= 0); + int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special); + remove_space = false; + if (n_chars < 0) { + avail = 0; + total -= n_chars; + } else if (n_chars > 0) { + avail -= n_chars; + text += n_chars; + total += n_chars; + } + } + + if (total > text_len_max) { + return -total; + } + + if (vocab.tokenizer_clean_spaces) { + text -= total; // restart text + + // first pass: characters ?!., //TODO: where do these characters come from? + const int32_t total1 = total; + total = total ? 1 : 0; + for (int32_t i = 1; i < total1; ++i) { + const char x = text[i]; + if (text[i - 1] == ' ') { + if (x == '?' || x == '!' || x == '.' || x == ',') { // " ?", " !", " .", " ," + total--; // remove space + } + } + text[total++] = x; + } + + // second pass: strip single apostrophe between spaces + const int32_t total2 = total; + total = total ? 1 : 0; + for (int32_t i = 1; i < total2; ++i) { + const char x = text[i]; + if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') { // " ' " + total--; // remove prev space + text[++i] = '\0'; // remove next space + } + text[total++] = x; + } + + // third pass: apostrophe contractions //NOTE: this makes sense? + const int32_t total3 = total; + total = total ? 1 : 0; + for (int32_t i = 1; i < total3; ++i) { + const char x = text[i]; + if (text[i - 1] == ' ') { + if (x == '\'' && i + 1 < total3) { + const char x1 = text[i + 1]; + if (x1 == 't' || x1 == 'd') { // " 't", " 'd" + //total--; // remove space + } else if (x1 == 's' || x1 == 'm') { // " 's", " 'm" + total--; // remove space + } else if (i + 2 < total3) { + const char x2 = text[i + 2]; + if ((x1 == 'l' && x2 == 'l')) { // " 'll" + //total--; // remove space + } else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) { // " 're", " 've" + total--; // remove space + } else { + //total--; // remove space + } + } else { + //total--; // remove space + } + } + } + text[total++] = x; + } + } + + return total <= text_len_max ? total : -total; +} diff --git a/llama/llama-vocab.h b/llama/llama-vocab.h new file mode 100644 index 00000000..84826366 --- /dev/null +++ b/llama/llama-vocab.h @@ -0,0 +1,156 @@ +/** + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#pragma once + +#include "llama-impl.h" + +#include +#include +#include +#include + +struct llama_vocab { + using id = llama_token; + using token = std::string; + using tattr = llama_token_attr; + + struct token_data { + token text; + float score; + tattr attr; + }; + + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + + int max_token_len = 0; // used for optimizing longest token search + + std::unordered_map token_to_id; + std::vector id_to_token; + + std::vector cache_special_tokens; + std::vector cache_token_to_piece; // llama_token_to_piece(special = true); + + std::map, int> bpe_ranks; + + // default LLaMA special tokens + id special_bos_id = 1; + id special_eos_id = 2; + id special_unk_id = 0; + id special_sep_id = -1; + id special_pad_id = -1; + id special_cls_id = -1; + id special_mask_id = -1; + + id linefeed_id = 13; + id special_prefix_id = -1; + id special_suffix_id = -1; + id special_middle_id = -1; + id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token + + // tokenizer flags + bool tokenizer_add_space_prefix = false; + bool tokenizer_add_bos = false; + bool tokenizer_add_eos = false; + bool tokenizer_ignore_merges = false; + bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces + bool tokenizer_remove_extra_whitespaces = false; + bool tokenizer_escape_whitespaces = true; + bool tokenizer_treat_whitespace_as_suffix = false; + + std::vector precompiled_charsmap; + + int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; +}; + +const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx); + +// +// internal API +// + +// TODO: rename to llama_tokenize_impl +// TODO: This should probably be in llama.h +std::vector llama_tokenize_internal( + const llama_vocab & vocab, + std::string raw_text, + bool add_special, + bool parse_special = false); + +llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch); + +const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token); + +float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token); + +llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token); + +bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token); + +bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token); + +llama_token llama_token_bos_impl(const struct llama_vocab & vocab); +llama_token llama_token_eos_impl(const struct llama_vocab & vocab); +llama_token llama_token_cls_impl(const struct llama_vocab & vocab); +llama_token llama_token_sep_impl(const struct llama_vocab & vocab); +llama_token llama_token_nl_impl (const struct llama_vocab & vocab); +llama_token llama_token_pad_impl(const struct llama_vocab & vocab); + +int32_t llama_add_bos_token_impl(const struct llama_vocab & vocab); +int32_t llama_add_eos_token_impl(const struct llama_vocab & vocab); + +llama_token llama_token_prefix_impl(const struct llama_vocab & vocab); +llama_token llama_token_middle_impl(const struct llama_vocab & vocab); +llama_token llama_token_suffix_impl(const struct llama_vocab & vocab); +llama_token llama_token_eot_impl (const struct llama_vocab & vocab); + +int32_t llama_tokenize_impl( + const struct llama_vocab & vocab, + const char * text, + int32_t text_len, + llama_token * tokens, + int32_t n_tokens_max, + bool add_special, + bool parse_special); + +// does not write null-terminator to buf +int32_t llama_token_to_piece_impl( + const struct llama_vocab & vocab, + llama_token token, + char * buf, + int32_t length, + int32_t lstrip, + bool special); + +int32_t llama_detokenize_impl( + const struct llama_vocab & vocab, + const llama_token * tokens, + int32_t n_tokens, + char * text, + int32_t text_len_max, + bool remove_special, + bool unparse_special); diff --git a/llama/llama.cpp b/llama/llama.cpp index b995e497..b95ed228 100644 --- a/llama/llama.cpp +++ b/llama/llama.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file + * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file * * MIT License * @@ -24,8 +24,10 @@ * SOFTWARE. */ -#define LLAMA_API_INTERNAL -#include "llama.h" +#include "llama-impl.h" +#include "llama-vocab.h" +#include "llama-grammar.h" +#include "llama-sampling.h" #include "unicode.h" @@ -45,6 +47,12 @@ # include "ggml-sycl.h" #elif defined(GGML_USE_KOMPUTE) # include "ggml-kompute.h" +#elif defined(GGML_USE_CANN) +# include "ggml-cann.h" +#endif + +#ifdef GGML_USE_BLAS +# include "ggml-blas.h" #endif #ifdef GGML_USE_METAL @@ -79,6 +87,12 @@ #include #endif +#if __cplusplus >= 202000L + #define LU8(x) (const char*)(u8##x) +#else + #define LU8(x) u8##x +#endif + #include #include #include @@ -93,7 +107,6 @@ #include #include #include -#include #include #include #include @@ -103,9 +116,6 @@ #include #include #include -#include -#include -#include #include #include #include @@ -116,39 +126,25 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -#ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) -#endif - -#define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 160 - -// -// logging -// - -LLAMA_ATTRIBUTE_FORMAT(2, 3) -static void llama_log_internal (ggml_log_level level, const char * format, ...); -static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); - -#define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) -#define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) -#define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +// bump if necessary +#define LLAMA_MAX_LAYERS 512 +#define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 // // helpers // -static size_t utf8_len(char src) { - const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t highbits = static_cast(src) >> 4; - return lookup[highbits]; +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && isspace(str[start])) { + start += 1; + } + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + return str.substr(start, end - start); } static void replace_all(std::string & s, const std::string & search, const std::string & replace) { @@ -239,14 +235,20 @@ enum llm_arch { LLM_ARCH_INTERNLM2, LLM_ARCH_MINICPM, LLM_ARCH_GEMMA, + LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_OPENELM, LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_CHATGLM, + LLM_ARCH_BITNET, + LLM_ARCH_T5, + LLM_ARCH_JAIS, LLM_ARCH_UNKNOWN, }; @@ -277,18 +279,25 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_INTERNLM2, "internlm2" }, { LLM_ARCH_MINICPM, "minicpm" }, { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OPENELM, "openelm" }, { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; enum llm_kv { + LLM_KV_GENERAL_TYPE, LLM_KV_GENERAL_ARCHITECTURE, LLM_KV_GENERAL_QUANTIZATION_VERSION, LLM_KV_GENERAL_ALIGNMENT, @@ -308,6 +317,7 @@ enum llm_kv { LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -316,6 +326,9 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, + LLM_KV_DECODER_START_TOKEN_ID, + LLM_KV_ATTN_LOGIT_SOFTCAPPING, + LLM_KV_FINAL_LOGIT_SOFTCAPPING, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -328,6 +341,8 @@ enum llm_kv { LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK, + LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, + LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, @@ -365,15 +380,21 @@ enum llm_kv { LLM_KV_TOKENIZER_ADD_BOS, LLM_KV_TOKENIZER_ADD_EOS, LLM_KV_TOKENIZER_ADD_PREFIX, + LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, + LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, LLM_KV_TOKENIZER_EOT_ID, + + LLM_KV_ADAPTER_TYPE, + LLM_KV_ADAPTER_LORA_ALPHA, }; static const std::map LLM_KV_NAMES = { + { LLM_KV_GENERAL_TYPE, "general.type" }, { LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" }, { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" }, { LLM_KV_GENERAL_ALIGNMENT, "general.alignment" }, @@ -386,33 +407,39 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, - { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, + { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" }, + { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" }, - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, - { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, - { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, - { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, @@ -433,29 +460,34 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, - { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, - { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, }; struct LLM_KV { @@ -486,10 +518,12 @@ enum llm_tensor { LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_NORM_2, LLM_TENSOR_ATTN_OUT_NORM, + LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_ATTN_ROT_EMBD, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, @@ -520,6 +554,36 @@ enum llm_tensor { LLM_TENSOR_ATTN_KV_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_SUB_NORM, + LLM_TENSOR_FFN_SUB_NORM, + LLM_TENSOR_DEC_ATTN_NORM, + LLM_TENSOR_DEC_ATTN_Q, + LLM_TENSOR_DEC_ATTN_K, + LLM_TENSOR_DEC_ATTN_V, + LLM_TENSOR_DEC_ATTN_OUT, + LLM_TENSOR_DEC_ATTN_REL_B, + LLM_TENSOR_DEC_CROSS_ATTN_NORM, + LLM_TENSOR_DEC_CROSS_ATTN_Q, + LLM_TENSOR_DEC_CROSS_ATTN_K, + LLM_TENSOR_DEC_CROSS_ATTN_V, + LLM_TENSOR_DEC_CROSS_ATTN_OUT, + LLM_TENSOR_DEC_CROSS_ATTN_REL_B, + LLM_TENSOR_DEC_FFN_NORM, + LLM_TENSOR_DEC_FFN_GATE, + LLM_TENSOR_DEC_FFN_DOWN, + LLM_TENSOR_DEC_FFN_UP, + LLM_TENSOR_DEC_OUTPUT_NORM, + LLM_TENSOR_ENC_ATTN_NORM, + LLM_TENSOR_ENC_ATTN_Q, + LLM_TENSOR_ENC_ATTN_K, + LLM_TENSOR_ENC_ATTN_V, + LLM_TENSOR_ENC_ATTN_OUT, + LLM_TENSOR_ENC_ATTN_REL_B, + LLM_TENSOR_ENC_FFN_NORM, + LLM_TENSOR_ENC_FFN_GATE, + LLM_TENSOR_ENC_FFN_DOWN, + LLM_TENSOR_ENC_FFN_UP, + LLM_TENSOR_ENC_OUTPUT_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -982,6 +1046,24 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_GEMMA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1082,6 +1164,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OPENELM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_ARCTIC, { @@ -1133,6 +1231,89 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_CHATGLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, + { + LLM_ARCH_BITNET, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" }, + }, + }, + { + LLM_ARCH_T5, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" }, + { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" }, + { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" }, + { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" }, + { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" }, + { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" }, + { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" }, + { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" }, + { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" }, + { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" }, + { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" }, + { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" }, + { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" }, + { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" }, + { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" }, + { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" }, + { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" }, + { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" }, + { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" }, + { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" }, + { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" }, + { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" }, + { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" }, + { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" }, + { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" }, + { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" }, + }, + }, + { + LLM_ARCH_JAIS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1300,6 +1481,126 @@ struct no_init { }; struct llama_file { + +#if defined(_WIN32) + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + HANDLE fp_win32; + size_t size; + +private: + std::string GetErrorMessageWin32(DWORD error_code) const { + std::string ret; + LPSTR lpMsgBuf = NULL; + DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); + if (!bufLen) { + ret = format("Win32 error code: %s", error_code); + } else { + ret = lpMsgBuf; + LocalFree(lpMsgBuf); + } + + return ret; + } + +public: + + llama_file(const char * fname, const char * mode) { + fp = ggml_fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { + // SetFilePointerEx returns the current position when seeking relative 0 bytes + LARGE_INTEGER li; + li.QuadPart = 0; + BOOL ret = SetFilePointerEx(fp_win32, li, &li, FILE_CURRENT); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + + return li.QuadPart; + } + + void seek(size_t offset, int whence) const { + // no need to convert SEEK_* to FILE_*. The enums are the same. + // Still, keep static asserts to avoid failures in the future. + static_assert(SEEK_SET == FILE_BEGIN, "SEEK_SET != FILE_BEGIN"); + static_assert(SEEK_CUR == FILE_CURRENT, "SEEK_CUR != FILE_CURRENT"); + static_assert(SEEK_END == FILE_END, "SEEK_END != FILE_END"); + + LARGE_INTEGER li; + li.QuadPart = offset; + BOOL ret = SetFilePointerEx(fp_win32, li, NULL, whence); + if (!ret) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + } + + void read_raw(void * ptr, size_t len) const { + // On Win32 ReadFile is significant faster than fread which is again significant faster than std::fstream. Thus + // use the Win32 API to do file io instead of the C/C++ library functions. + + // There are conditions under which ReadFile cannot read chunks >64MB. + // Thus split the operation into smaller chunks if len exceeds this limit. + size_t bytes_read = 0; + while (bytes_read < len) { + size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); + DWORD chunk_read = 0; + BOOL result = ReadFile(fp_win32, reinterpret_cast(ptr) + bytes_read, chunk_size, &chunk_read, NULL); + if (!result) { + throw std::runtime_error(format("read error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_read < chunk_size || chunk_read == 0) { + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += chunk_read; + } ; + } + + uint32_t read_u32() const { + uint32_t val; + read_raw(&val, sizeof(val)); + return val; + } + + void write_raw(const void * ptr, size_t len) const { + // There are conditions under which WriteFile cannot write chunks >64MB. + // Thus split the operation into smaller chunks if len exceeds this limit. + size_t bytes_written = 0; + while (bytes_written < len) { + size_t chunk_size = std::min(len - bytes_written, 64*1024*1024); + DWORD chunk_written = 0; + BOOL result = WriteFile(fp_win32, reinterpret_cast(ptr) + bytes_written, chunk_size, &chunk_written, NULL); + if (!result) { + throw std::runtime_error(format("write error: %s", GetErrorMessageWin32(GetLastError()).c_str())); + } + if (chunk_written < chunk_size || chunk_written == 0) { + throw std::runtime_error("unexpectedly failed to write bytes"); + } + + bytes_written += chunk_written; + } + } + + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +#else // use FILE * so we don't have to re-open the file to mmap FILE * fp; size_t size; @@ -1320,7 +1621,10 @@ struct llama_file { #else long ret = std::ftell(fp); #endif - GGML_ASSERT(ret != -1); // this really shouldn't fail + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + return (size_t) ret; } @@ -1330,7 +1634,9 @@ struct llama_file { #else int ret = std::fseek(fp, (long) offset, whence); #endif - GGML_ASSERT(ret == 0); // same + if (ret != 0) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } } void read_raw(void * ptr, size_t len) const { @@ -1373,6 +1679,7 @@ struct llama_file { std::fclose(fp); } } +#endif }; using llama_files = std::vector>; @@ -1729,18 +2036,19 @@ using llama_mlocks = std::vector>; // NOTE: avoid ever using this except for building the token_to_piece caches static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { - std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_token_to_piece(model, token, result.data(), result.size(), special); - GGML_ASSERT(check == -n_tokens); + std::string piece; + piece.resize(piece.capacity()); // using string internal cache + const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); } else { - result.resize(n_tokens); + piece.resize(n_chars); } - return std::string(result.data(), result.size()); + return piece; } static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer) { @@ -1781,6 +2089,8 @@ struct llama_state { ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data); #elif defined(GGML_USE_CUDA) ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data); +#elif defined(GGML_USE_CANN) + ggml_backend_cann_log_set_callback(log_callback, log_callback_user_data); #endif } @@ -1798,22 +2108,34 @@ enum e_model { MODEL_17M, MODEL_22M, MODEL_33M, + MODEL_60M, MODEL_70M, + MODEL_80M, MODEL_109M, MODEL_137M, MODEL_160M, + MODEL_220M, + MODEL_250M, + MODEL_270M, MODEL_335M, MODEL_410M, + MODEL_450M, + MODEL_770M, + MODEL_780M, MODEL_0_5B, MODEL_1B, + MODEL_1_3B, MODEL_1_4B, MODEL_2B, MODEL_2_8B, MODEL_3B, MODEL_4B, + MODEL_6B, MODEL_6_9B, MODEL_7B, MODEL_8B, + MODEL_9B, + MODEL_11B, MODEL_12B, MODEL_13B, MODEL_14B, @@ -1837,6 +2159,8 @@ enum e_model { MODEL_8x22B, MODEL_16x12B, MODEL_10B_128x3_66B, + MODEL_57B_A14B, + MODEL_27B, }; static const size_t kiB = 1024; @@ -1851,27 +2175,34 @@ struct llama_hparams { uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_head; - uint32_t n_head_kv; uint32_t n_layer; uint32_t n_rot; + uint32_t n_swa = 0; // sliding window attention (SWA) uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head - uint32_t n_ff; uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_rel_attn_bkts = 0; + + std::array n_head_arr; + std::array n_head_kv_arr; + std::array n_ff_arr; uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; uint32_t n_expert_shared = 0; float expert_weights_scale = 0.0; float f_norm_eps; float f_norm_rms_eps; + float f_attn_logit_softcapping = 50.0f; + float f_final_logit_softcapping = 30.0f; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; @@ -1888,8 +2219,13 @@ struct llama_hparams { float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; - bool causal_attn = true; - bool use_alibi = false; + bool causal_attn = true; + bool use_alibi = false; + bool attn_soft_cap = false; + + // needed by encoder-decoder models (e.g. T5, FLAN-T5) + // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + llama_token dec_start_token_id = -1; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -1900,20 +2236,24 @@ struct llama_hparams { if (this->n_vocab != other.n_vocab) return true; if (this->n_ctx_train != other.n_ctx_train) return true; if (this->n_embd != other.n_embd) return true; - if (this->n_head != other.n_head) return true; - if (this->n_head_kv != other.n_head_kv) return true; if (this->n_layer != other.n_layer) return true; if (this->n_rot != other.n_rot) return true; + if (this->n_swa != other.n_swa) return true; if (this->n_embd_head_k != other.n_embd_head_k) return true; if (this->n_embd_head_v != other.n_embd_head_v) return true; - if (this->n_ff != other.n_ff) return true; if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_head_arr != other.n_head_arr) return true; + if (this->n_head_kv_arr != other.n_head_kv_arr) return true; + if (this->n_ff_arr != other.n_ff_arr) return true; + + if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; if (this->n_lora_q != other.n_lora_q) return true; if (this->n_lora_kv != other.n_lora_kv) return true; if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_ff_shexp != other.n_ff_shexp) return true; if (this->n_expert_shared != other.n_expert_shared) return true; if (this->rope_finetuned != other.rope_finetuned) return true; @@ -1924,6 +2264,8 @@ struct llama_hparams { if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->dec_start_token_id != other.dec_start_token_id) return true; + const float EPSILON = 1e-9f; if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; @@ -1937,18 +2279,50 @@ struct llama_hparams { return false; } - uint32_t n_gqa() const { + uint32_t n_head(uint32_t il = 0) const { + if (il < n_layer) { + return n_head_arr[il]; + } + + GGML_ABORT("fatal error"); + } + + uint32_t n_head_kv(uint32_t il = 0) const { + if (il < n_layer) { + return n_head_kv_arr[il]; + } + + GGML_ABORT("fatal error"); + } + + uint32_t n_ff(uint32_t il = 0) const { + if (il < n_layer) { + return n_ff_arr[il]; + } + + GGML_ABORT("fatal error"); + } + + uint32_t n_gqa(uint32_t il = 0) const { + const uint32_t n_head = this->n_head(il); + const uint32_t n_head_kv = this->n_head_kv(il); + if (n_head_kv == 0) { return 0; } + return n_head/n_head_kv; } - uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t il = 0) const { // dimension of key embeddings across all k-v heads + const uint32_t n_head_kv = this->n_head_kv(il); + return n_embd_head_k * n_head_kv; } - uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t il = 0) const { // dimension of value embeddings across all k-v heads + const uint32_t n_head_kv = this->n_head_kv(il); + return n_embd_head_v * n_head_kv; } @@ -1965,6 +2339,8 @@ struct llama_hparams { } }; +static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable"); + struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; @@ -1996,6 +2372,7 @@ struct llama_cparams { void * cb_eval_user_data; }; +// TODO: separate into "llama_layer_enc" and "llama_layer_dec" struct llama_layer { // normalization struct ggml_tensor * attn_norm; @@ -2010,6 +2387,11 @@ struct llama_layer { struct ggml_tensor * attn_out_norm_b; struct ggml_tensor * attn_q_a_norm; struct ggml_tensor * attn_kv_a_norm; + struct ggml_tensor * attn_sub_norm; + struct ggml_tensor * attn_post_norm; + struct ggml_tensor * ffn_sub_norm; + struct ggml_tensor * attn_norm_cross; + struct ggml_tensor * attn_norm_enc; // attention struct ggml_tensor * wq; @@ -2021,6 +2403,14 @@ struct llama_layer { struct ggml_tensor * wq_b; struct ggml_tensor * wkv_a_mqa; struct ggml_tensor * wkv_b; + struct ggml_tensor * wq_cross; + struct ggml_tensor * wk_cross; + struct ggml_tensor * wv_cross; + struct ggml_tensor * wo_cross; + struct ggml_tensor * wq_enc; + struct ggml_tensor * wk_enc; + struct ggml_tensor * wv_enc; + struct ggml_tensor * wo_enc; // attention bias struct ggml_tensor * bq; @@ -2029,17 +2419,27 @@ struct llama_layer { struct ggml_tensor * bo; struct ggml_tensor * bqkv; + // relative position bias + struct ggml_tensor * attn_rel_b; + struct ggml_tensor * attn_rel_b_enc; + struct ggml_tensor * attn_rel_b_cross; + // normalization struct ggml_tensor * ffn_norm; struct ggml_tensor * ffn_norm_b; + struct ggml_tensor * ffn_post_norm; struct ggml_tensor * layer_out_norm; struct ggml_tensor * layer_out_norm_b; struct ggml_tensor * ffn_norm_exps; + struct ggml_tensor * ffn_norm_enc; // ff struct ggml_tensor * ffn_gate; // w1 struct ggml_tensor * ffn_down; // w2 struct ggml_tensor * ffn_up; // w3 + struct ggml_tensor * ffn_gate_enc; + struct ggml_tensor * ffn_down_enc; + struct ggml_tensor * ffn_up_enc; // ff MoE struct ggml_tensor * ffn_gate_inp; @@ -2077,6 +2477,16 @@ struct llama_layer { // long rope factors struct ggml_tensor * rope_long = nullptr; struct ggml_tensor * rope_short = nullptr; + struct ggml_tensor * rope_freqs = nullptr; + + // bitnet scale + struct ggml_tensor * wq_scale; + struct ggml_tensor * wk_scale; + struct ggml_tensor * wv_scale; + struct ggml_tensor * wo_scale; + struct ggml_tensor * ffn_gate_scale; + struct ggml_tensor * ffn_up_scale; + struct ggml_tensor * ffn_down_scale; }; struct llama_kv_cell { @@ -2154,13 +2564,21 @@ struct llama_control_vector { int32_t layer_start = -1; int32_t layer_end = -1; - ggml_tensor * tensor_for(int il) const { + struct ggml_tensor * tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } return tensors[il]; } + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { + ggml_tensor * layer_dir = tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx, cur, layer_dir); + } + return cur; + } + ~llama_control_vector() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); @@ -2171,63 +2589,6 @@ struct llama_control_vector { } }; -struct llama_vocab { - using id = int32_t; - using token = std::string; - using tattr = llama_token_attr; - - struct token_data { - token text; - float score; - tattr attr; - }; - - enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; - enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - - std::unordered_map token_to_id; - std::vector id_to_token; - - std::vector cache_special_tokens; - std::vector cache_token_to_piece; // llama_token_to_piece(special = true); - - std::map, int> bpe_ranks; - - // default LLaMA special tokens - id special_bos_id = 1; - id special_eos_id = 2; - id special_unk_id = 0; - id special_sep_id = -1; - id special_pad_id = -1; - id special_cls_id = -1; - id special_mask_id = -1; - - int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add. - int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add. - - id linefeed_id = 13; - id special_prefix_id = -1; - id special_suffix_id = -1; - id special_middle_id = -1; - id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token - - bool add_space_prefix = true; - - int find_bpe_rank(const std::string & token_left, const std::string & token_right) const { - GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); - GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); - - auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); - if (it == bpe_ranks.end()) { - return -1; - } - - return it->second; - } -}; - struct llama_model { e_model type = MODEL_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN; @@ -2248,6 +2609,7 @@ struct llama_model { struct ggml_tensor * output_norm_b; struct ggml_tensor * output; struct ggml_tensor * output_b; + struct ggml_tensor * output_norm_enc; std::vector layers; @@ -2293,6 +2655,9 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; + // keep track of loaded lora adapters + std::set lora_adapters; + ~llama_model() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); @@ -2305,11 +2670,19 @@ struct llama_model { #endif ggml_backend_buffer_free(buf); } + while (!lora_adapters.empty()) { + llama_lora_adapter_free(*lora_adapters.begin()); + } } }; struct llama_context { - llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} + llama_context(const llama_model & model) + : model(model) + , sampling(llama_n_vocab(&model)) + , t_start_us(model.t_start_us) + , t_load_us(model.t_load_us) {} + ~llama_context() { ggml_backend_sched_free(sched); @@ -2320,33 +2693,34 @@ struct llama_context { ggml_backend_buffer_free(buf_output); } - llama_cparams cparams; + const struct llama_model & model; + + struct llama_cparams cparams; + struct llama_sampling sampling; + struct llama_kv_cache kv_self; + struct llama_control_vector cvec; + + std::unordered_map lora_adapters; std::vector backends; #ifdef GGML_USE_METAL ggml_backend_t backend_metal = nullptr; +#endif +#ifdef GGML_USE_BLAS + ggml_backend_t backend_blas = nullptr; #endif ggml_backend_t backend_cpu = nullptr; - const llama_model & model; - - // key + value cache for the self attention - struct llama_kv_cache kv_self; - - std::mt19937 rng; - bool has_evaluated_once = false; int64_t t_start_us; int64_t t_load_us; - int64_t t_sample_us = 0; int64_t t_p_eval_us = 0; int64_t t_eval_us = 0; int64_t t_compute_start_us = 0; int64_t n_queued_tokens = 0; - int32_t n_sample = 0; // number of tokens sampled int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) int32_t n_eval = 0; // number of eval calls @@ -2372,6 +2746,13 @@ struct llama_context { // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; + // whether we are computing encoder output or decoder output + bool is_encoding = false; + + // output of the encoder part of the encoder-decoder models + std::vector embd_enc; + std::vector> seq_ids_enc; + // memory buffers used to evaluate the model std::vector buf_compute_meta; ggml_backend_sched_t sched = nullptr; @@ -2380,20 +2761,64 @@ struct llama_context { void * abort_callback_data = nullptr; // input tensors - struct ggml_tensor * inp_tokens; // I32 [n_batch] - struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - struct ggml_tensor * inp_pos; // I32 [n_batch] - struct ggml_tensor * inp_out_ids; // I32 [n_outputs] - struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_K_shift; // I32 [kv_size] - struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] - struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_tokens; // I32 [n_batch] + struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] + struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_out_ids; // I32 [n_outputs] + struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_K_shift; // I32 [kv_size] + struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] + struct ggml_tensor * inp_cls; // I32 [n_batch] + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] + struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] +}; - // control vectors - struct llama_control_vector cvec; +struct llama_lora_weight { + struct ggml_tensor * a = nullptr; + struct ggml_tensor * b = nullptr; + llama_lora_weight() = default; + llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b): a(a), b(b) {} +}; + +struct llama_lora_adapter { + struct llama_model * base_model; + // map tensor name to lora_a_b + std::unordered_map ab_map; + std::vector ctxs; + std::vector bufs; + + float alpha; + + llama_lora_adapter(struct llama_model * base_model): base_model(base_model) { + base_model->lora_adapters.insert(this); + } + + llama_lora_weight * get_weight(struct ggml_tensor * w) { + std::string name(w->name); + auto pos = ab_map.find(name); + if (ab_map.find(name) != ab_map.end()) { + return &pos->second; + } + return nullptr; + } + + ~llama_lora_adapter() { + for (struct ggml_context * ctx : ctxs) { + ggml_free(ctx); + } + for (ggml_backend_buffer_t buf : bufs) { + ggml_backend_buffer_free(buf); + } + auto pos = base_model->lora_adapters.find(this); + if (pos != base_model->lora_adapters.end()) { + base_model->lora_adapters.erase(pos); + } + } }; static size_t llama_get_device_count(const llama_model & model) { @@ -2404,6 +2829,8 @@ static size_t llama_get_device_count(const llama_model & model) { count = ggml_backend_sycl_get_device_count(); #elif defined(GGML_USE_VULKAN) count = ggml_backend_vk_get_device_count(); +#elif defined(GGML_USE_CANN) + return ggml_backend_cann_get_device_count(); #endif #if defined(GGML_USE_RPC) count += model.rpc_servers.size(); @@ -2436,6 +2863,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_ if (buft == nullptr) { LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); } +#elif defined(GGML_USE_CANN) + buft = ggml_backend_cann_buffer_type(gpu); #endif if (buft == nullptr) { @@ -2496,6 +2925,11 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { size_t free; ggml_backend_vk_get_device_memory(device, &free, &total); return free; +#elif defined(GGML_USE_CANN) + size_t total; + size_t free; + ggml_backend_cann_get_device_memory(device, &free, &total); + return free; #else return 1; #endif @@ -2519,22 +2953,13 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); - const int64_t n_layer = hparams.n_layer; + const int64_t n_layer = hparams.n_layer; cache.has_shift = false; // TODO: find a nicer way to add other recurrent model architectures cache.recurrent = model.arch == LLM_ARCH_MAMBA; - cache.v_trans = !cparams.flash_attn; - - // TODO: support mixed recurrent Transformer architectures - // NOTE: (!a || b) is a logical implication (a -> b) - GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); - GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); - GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); - GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); + cache.v_trans = !cache.recurrent && !cparams.flash_attn; cache.head = 0; cache.size = kv_size; @@ -2585,6 +3010,9 @@ static bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < (int) n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); @@ -2865,6 +3293,8 @@ static void llama_kv_cache_seq_add( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; if (cache.recurrent) { // for Mamba-like models, only the pos needs to be shifted @@ -2909,6 +3339,8 @@ static void llama_kv_cache_seq_div( int d) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) return; if (cache.recurrent) { // for Mamba-like models, only the pos needs to be changed @@ -3161,6 +3593,15 @@ namespace GGUFMeta { using llama_buf_map = std::unordered_map; +// TODO: update when needed or think of some clever automatic way to do this +static size_t llama_model_max_nodes(const llama_model & /*model*/) { + //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B + // return 32768; + //} + + return 8192; +} + struct llama_model_loader { int n_kv = 0; int n_tensors = 0; @@ -3211,7 +3652,7 @@ struct llama_model_loader { } if (param_overrides_p != nullptr) { - for (const struct llama_model_kv_override *p = param_overrides_p; p->key[0] != 0; p++) { + for (const struct llama_model_kv_override * p = param_overrides_p; p->key[0] != 0; p++) { kv_overrides.insert({std::string(p->key), *p}); } } @@ -3365,6 +3806,9 @@ struct llama_model_loader { case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; + case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; + case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -3376,7 +3820,7 @@ struct llama_model_loader { ftype = (llama_ftype) (ftype | LLAMA_FTYPE_GUESSED); { - const int kid = gguf_find_key(meta, "general.file_type"); + const int kid = gguf_find_key(meta, "general.file_type"); // TODO: use LLM_KV if (kid >= 0) { ftype = (llama_ftype) gguf_get_val_u32(meta, kid); } @@ -3460,9 +3904,9 @@ struct llama_model_loader { bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = gguf_find_key(meta, key.c_str()); - if (kid < 0) { + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { if (required) { - throw std::runtime_error(format("key not found in model: %s", key.c_str())); + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } return false; } @@ -3470,22 +3914,55 @@ struct llama_model_loader { struct GGUFMeta::ArrayInfo arr_info = GGUFMeta::GKV::get_kv(meta, kid); - if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { - throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); } - // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); - result.resize(arr_info.length); result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } + template + bool get_arr(const std::string & key, std::array & result, const bool required = true) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT( + (std::is_same::value) || + (std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + } + + if (arr_info.length > N_MAX) { + throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); + } + + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + + return true; + } + template - bool get_arr(const enum llm_kv kid, T& result, const bool required = true) { + bool get_arr(const enum llm_kv kid, T & result, const bool required = true) { return get_arr(llm_kv(kid), result, required); } @@ -3510,6 +3987,52 @@ struct llama_model_loader { return get_key(llm_kv(kid), result, required); } + // get array of n <= N_MAX elements, or a single element repeated n times + template + bool get_key_or_arr(const std::string & key, std::array & result, uint32_t n, const bool required = true) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + if (n > N_MAX) { + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + } + + if (gguf_get_kv_type(meta, kid) == GGUF_TYPE_ARRAY) { + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (n != arr_info.length) { + throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); + } + + return get_arr(key, result, required); + } else { + T value; + + bool ok = get_key(key, value, required); + if (!ok) { + return false; + } + + for (uint32_t i = 0; i < n; i++) { + result[i] = value; + } + + return true; + } + } + + template + bool get_key_or_arr(const enum llm_kv kid, T & result, uint32_t n, const bool required = true) { + return get_key_or_arr(llm_kv(kid), result, n, required); + } + std::string get_arch_name() const { return arch_name; } @@ -3739,6 +4262,44 @@ struct llama_model_loader { std::vector> read_buf; std::vector>> validation_result; +#if defined(GGML_USE_CUDA) + // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. + // NVMe raid configurations might require more / larger buffers. + constexpr size_t n_buffers = 4; + constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + std::vector host_buffers; + std::vector host_ptrs; + std::vector events; + size_t buffer_idx = 0; // buffer to use for async loads + + ggml_backend_t cuda_backend = nullptr; + if (!use_mmap && !check_tensors) { + // When not using mmaped io use async uploads from pinned memory to GPU memory. + // First determine if the CUDA backend is active, and if so, determine the device ID. + ggml_backend_buffer_t buf = bufs_mmap.count(0) ? bufs_mmap.at(0) : nullptr; + if (buf) { + ggml_backend_buffer_type_t buffer_type = ggml_backend_buffer_get_type(buf); + for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) { + auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i); + if (buffer_type == cuda_buffer_type) { + cuda_backend = ggml_backend_cuda_init(i); + break; + } + } + } + + // If the cuda backend is active create pinned memory buffers and events for synchronisation. + if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers; ++idx) { + host_buffers.emplace_back(ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buffer_size)); + host_ptrs.emplace_back(ggml_backend_buffer_get_base(host_buffers[idx])); + events.emplace_back(ggml_backend_event_new(cuda_backend)); + } + } + } +#endif + for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur != NULL; cur = ggml_get_next_tensor(ctx, cur)) { const auto * weight = get_weight(ggml_get_name(cur)); if (weight == nullptr) { @@ -3794,12 +4355,36 @@ struct llama_model_loader { })); } } else { - read_buf.resize(n_size); - file->seek(weight->offs, SEEK_SET); - file->read_raw(read_buf.data(), n_size); - ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); - if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { - throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); +#if defined(GGML_USE_CUDA) + // If cuda_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. + if (cuda_backend) { + file->seek(weight->offs, SEEK_SET); + + size_t bytes_read = 0; + + while (bytes_read < n_size) { + size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + + ggml_backend_event_synchronize(events[buffer_idx]); + file->read_raw(host_ptrs[buffer_idx], read_iteration); + ggml_backend_tensor_set_async(cuda_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + ggml_backend_event_record(events[buffer_idx]); + + bytes_read += read_iteration; + ++buffer_idx; + buffer_idx %= n_buffers; + } + } + else +#endif + { + read_buf.resize(n_size); + file->seek(weight->offs, SEEK_SET); + file->read_raw(read_buf.data(), n_size); + ggml_backend_tensor_set(cur, read_buf.data(), 0, n_size); + if (check_tensors && !ggml_validate_row_data(cur->type, read_buf.data(), n_size)) { + throw std::runtime_error(format("tensor '%s' has invalid data", ggml_get_name(cur))); + } } } } @@ -3807,6 +4392,18 @@ struct llama_model_loader { size_done += n_size; } +#if defined(GGML_USE_CUDA) + // free temporary resources used for async cuda uploads + if (cuda_backend) { + for (size_t idx = 0; idx < n_buffers;++idx) { + ggml_backend_event_synchronize(events[idx]); + ggml_backend_event_free(events[idx]); + ggml_backend_buffer_free(host_buffers[idx]); + } + ggml_backend_free(cuda_backend); + } +#endif + // check validation results bool validation_failed = false; for (auto & future : validation_result) { @@ -3875,40 +4472,39 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { } switch (ftype) { - case LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "F16"; - case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: - return "Q4_1, some F16"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; - - // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS:return "IQ2_XXS - 2.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_M :return "IQ1_M - 1.75 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; + case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8"; + case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8"; default: return "unknown, may not work"; } @@ -3920,22 +4516,34 @@ static const char * llama_model_type_name(e_model type) { case MODEL_17M: return "17M"; case MODEL_22M: return "22M"; case MODEL_33M: return "33M"; + case MODEL_60M: return "60M"; case MODEL_70M: return "70M"; + case MODEL_80M: return "80M"; case MODEL_109M: return "109M"; case MODEL_137M: return "137M"; case MODEL_160M: return "160M"; + case MODEL_220M: return "220M"; + case MODEL_250M: return "250M"; + case MODEL_270M: return "270M"; case MODEL_335M: return "335M"; case MODEL_410M: return "410M"; + case MODEL_450M: return "450M"; + case MODEL_770M: return "770M"; + case MODEL_780M: return "780M"; case MODEL_0_5B: return "0.5B"; case MODEL_1B: return "1B"; + case MODEL_1_3B: return "1.3B"; case MODEL_1_4B: return "1.4B"; case MODEL_2B: return "2B"; case MODEL_2_8B: return "2.8B"; case MODEL_3B: return "3B"; case MODEL_4B: return "4B"; + case MODEL_6B: return "6B"; case MODEL_6_9B: return "6.9B"; case MODEL_7B: return "7B"; case MODEL_8B: return "8B"; + case MODEL_9B: return "9B"; + case MODEL_11B: return "11B"; case MODEL_12B: return "12B"; case MODEL_13B: return "13B"; case MODEL_14B: return "14B"; @@ -3959,6 +4567,8 @@ static const char * llama_model_type_name(e_model type) { case MODEL_8x22B: return "8x22B"; case MODEL_16x12B: return "16x12B"; case MODEL_10B_128x3_66B: return "10B+128x3.66B"; + case MODEL_57B_A14B: return "57B.A14B"; + case MODEL_27B: return "27B"; default: return "?B"; } } @@ -3969,6 +4579,7 @@ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ case LLAMA_VOCAB_TYPE_SPM: return "SPM"; case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; default: return "unknown"; } } @@ -4001,20 +4612,18 @@ static void llm_load_hparams( ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); // get hparams kv - ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); // everything past this point is not vocab-related if (hparams.vocab_only) { return; } - ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); - ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); @@ -4024,9 +4633,18 @@ static void llm_load_hparams( GGML_ASSERT(hparams.n_expert_used == 0); } + // zero-out the per-layer hparams + std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); + std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); + std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); + // n_head_kv is optional, default to n_head - hparams.n_head_kv = hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + hparams.n_head_kv_arr = hparams.n_head_arr; + + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -4054,27 +4672,33 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); - // sanity check for n_rot (optional) - { - hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; + // non-transformer models do not have attention heads + if (hparams.n_head() > 0) { + // gpt-neox n_rot = rotary_pct * (n_embd / n_head) + // gpt-j n_rot = rotary_dim + + hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + + hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + + // sanity check for n_rot (optional) + hparams.n_rot = hparams.n_embd_head_k; ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { - if (hparams.n_rot != hparams.n_embd / hparams.n_head) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head)); + if (hparams.n_rot != hparams.n_embd_head_k) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } } - // gpt-neox n_rot = rotary_pct * (n_embd / n_head) - // gpt-j n_rot = rotary_dim + } else { + hparams.n_rot = 0; + hparams.n_embd_head_k = 0; + hparams.n_embd_head_v = 0; } - hparams.n_embd_head_k = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - - hparams.n_embd_head_v = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); - // arch-specific KVs switch (model.arch) { case LLM_ARCH_LLAMA: @@ -4097,7 +4721,7 @@ static void llm_load_hparams( case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head == hparams.n_head_kv ? e_model::MODEL_65B : e_model::MODEL_70B; break; + case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break; default: model.type = e_model::MODEL_UNKNOWN; } } @@ -4266,16 +4890,20 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = hparams.n_head == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; + case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; case 80: model.type = e_model::MODEL_70B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; case LLM_ARCH_QWEN2MOE: { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_A2_7B; break; + case 28: model.type = e_model::MODEL_57B_A14B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4291,6 +4919,7 @@ static void llm_load_hparams( } break; case LLM_ARCH_PHI3: { + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { @@ -4324,7 +4953,7 @@ static void llm_load_hparams( { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 42: model.type = e_model::MODEL_SMALL; break; + case 42: model.type = e_model::MODEL_7B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4356,6 +4985,21 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GEMMA2: + { + hparams.n_swa = 4096; // default value of gemma 2 + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + hparams.attn_soft_cap = true; + + switch (hparams.n_layer) { + case 42: model.type = e_model::MODEL_9B; break; + case 46: model.type = e_model::MODEL_27B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -4439,46 +5083,58 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OPENELM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: model.type = e_model::MODEL_270M; break; + case 20: model.type = e_model::MODEL_450M; break; + case 28: model.type = e_model::MODEL_1B; break; + case 36: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_GPTNEOX: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); switch (hparams.n_layer) { case 6: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 512: model.type = e_model::MODEL_14M; break; case 2048: model.type = e_model::MODEL_70M; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 12: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 3072: model.type = e_model::MODEL_160M; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 16: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 8192: model.type = e_model::MODEL_1B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 24: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 4096: model.type = e_model::MODEL_410M; break; case 8192: model.type = e_model::MODEL_1_4B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 32: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 10240: model.type = e_model::MODEL_2_8B; break; case 16384: model.type = e_model::MODEL_6_9B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 36: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 20480: model.type = e_model::MODEL_12B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; case 44: - switch (hparams.n_ff) { + switch (hparams.n_ff()) { case 24576: model.type = e_model::MODEL_20B; break; default: model.type = e_model::MODEL_UNKNOWN; } break; @@ -4518,6 +5174,68 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_CHATGLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_6B; break; + case 40: model.type = e_model::MODEL_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_BITNET: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 26: model.type = e_model::MODEL_3B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_T5: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + switch (hparams.n_layer) { + case 6: model.type = e_model::MODEL_60M; break; // t5-small + case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: model.type = e_model::MODEL_220M; break; // t5-base + case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: model.type = e_model::MODEL_770M; break; // t5-large + case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large + case 16384: model.type = e_model::MODEL_3B; break; // t5-3b + case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl + case 65536: model.type = e_model::MODEL_11B; break; // t5-11b + case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_JAIS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + + switch (hparams.n_layer) { + case 24: model.type = e_model::MODEL_1_3B; break; + case 40: model.type = e_model::MODEL_13B; break; + /* TODO: add variants */ + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4530,12 +5248,6 @@ static void llm_load_hparams( hparams.rope_type = llama_rope_type(&model); } -// TODO: This should probably be in llama.h -static std::vector llama_tokenize_internal( - const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special = false -); -static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); - static void llm_load_vocab( llama_model_loader & ml, llama_model & model) { @@ -4578,40 +5290,6 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; - - // For Fill-In-the-Middle (FIM)/infill models which where converted - // prior to support of FIM special tokens in GGUF, the following - // will allow those models to continue to work. The general names - // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and - // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once - // new versions of these models have been published. - std::string gen_name; - ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); - - std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), - [](unsigned char c){ return std::tolower(c); }); - - if (gen_name.find("code") != std::string::npos) { - if (model.arch == LLM_ARCH_LLAMA) { - vocab.special_prefix_id = 32007; - vocab.special_suffix_id = 32008; - vocab.special_middle_id = 32009; - vocab.special_eot_id = 32010; - } else if (model.arch == LLM_ARCH_GEMMA) { - vocab.special_prefix_id = 67; - vocab.special_suffix_id = 69; - vocab.special_middle_id = 68; - // TODO: this is not EOT, it is "file separator" token, needs fix - // https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572 - //vocab.special_eot_id = 70; - vocab.special_eot_id = 107; - } - } - - const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); - if (add_space_prefix_keyidx != -1) { - vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); - } // The default value of add_space_prefix is true. } else if (tokenizer_model == "bert") { vocab.type = LLAMA_VOCAB_TYPE_WPM; @@ -4623,15 +5301,9 @@ static void llm_load_vocab( vocab.special_pad_id = 0; vocab.special_cls_id = 101; vocab.special_mask_id = 103; - vocab.add_space_prefix = false; } else if (tokenizer_model == "gpt2") { vocab.type = LLAMA_VOCAB_TYPE_BPE; - const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); - if (add_space_prefix_keyidx != -1) { - vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); - } - // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { @@ -4639,7 +5311,6 @@ static void llm_load_vocab( } const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); @@ -4665,12 +5336,43 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; + } else if (tokenizer_model == "t5") { + vocab.type = LLAMA_VOCAB_TYPE_UGM; + + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = 1; + vocab.special_unk_id = 2; + vocab.special_sep_id = -1; + vocab.special_pad_id = 0; + vocab.special_cls_id = -1; + vocab.special_mask_id = -1; + + const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); + if (precompiled_charsmap_keyidx != -1) { + size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); + vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); +#ifdef IS_BIG_ENDIAN + // correct endiannes of data in precompiled_charsmap binary blob + uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; + *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); + assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); + size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); + uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; + for (size_t i = 0; i < xcda_array_size; ++i) { + xcda_array[i] = __builtin_bswap32(xcda_array[i]); + } +#endif + } } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } // for now, only BPE models have pre-tokenizers if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = true; if (tokenizer_pre == "default") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( @@ -4678,12 +5380,16 @@ static void llm_load_vocab( tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; + vocab.tokenizer_ignore_merges = true; + vocab.tokenizer_add_bos = true; } else if ( tokenizer_pre == "deepseek-llm") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "deepseek-coder") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "falcon") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -4695,6 +5401,7 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; } else if ( tokenizer_pre == "gpt-2" || + tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || tokenizer_pre == "jina-v2-es" || @@ -4707,9 +5414,11 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "command-r") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "qwen2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + vocab.tokenizer_clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; @@ -4722,13 +5431,60 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "smaug-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; + } else if ( + tokenizer_pre == "poro-chat") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "chatglm-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; + vocab.special_bos_id = -1; + } else if ( + tokenizer_pre == "viking") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "jais") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; + } else if ( + tokenizer_pre == "tekken") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_ignore_merges = true; + vocab.tokenizer_add_bos = true; + } else if ( + tokenizer_pre == "smollm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM; + vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "codeshell") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; } else { LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } + } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = true; + vocab.tokenizer_clean_spaces = false; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_space_prefix = false; + vocab.tokenizer_clean_spaces = true; + vocab.tokenizer_add_bos = true; + vocab.tokenizer_add_eos = false; + } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.tokenizer_add_bos = false; + vocab.tokenizer_add_eos = true; } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } + + ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.tokenizer_add_space_prefix, false); + ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false); } const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); @@ -4757,6 +5513,7 @@ static void llm_load_vocab( GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); vocab.token_to_id[word] = i; + vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); @@ -4780,8 +5537,46 @@ static void llm_load_vocab( // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { + // For Fill-In-the-Middle (FIM)/infill models which where converted + // prior to support of FIM special tokens in GGUF, the following + // will allow those models to continue to work. The general names + // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and + // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once + // new versions of these models have been published. + std::string gen_name; + ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); + + std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(), + [](unsigned char c){ return std::tolower(c); }); + + if (gen_name.find("code") != std::string::npos) { + if (model.arch == LLM_ARCH_LLAMA + && 32010 < vocab.id_to_token.size() + && vocab.id_to_token[32007].text.find("
") != std::string::npos
+              && vocab.id_to_token[32008].text.find("") != std::string::npos
+              && vocab.id_to_token[32009].text.find("") != std::string::npos
+              && vocab.id_to_token[32010].text.find("") != std::string::npos) {
+                vocab.special_prefix_id = 32007;
+                vocab.special_suffix_id = 32008;
+                vocab.special_middle_id = 32009;
+                vocab.special_eot_id    = 32010;
+            } else if (model.arch == LLM_ARCH_GEMMA
+              && 107 < vocab.id_to_token.size()
+              && vocab.id_to_token[67].text == "<|fim_prefix|>"
+              && vocab.id_to_token[69].text == "<|fim_suffix|>"
+              && vocab.id_to_token[68].text == "<|fim_middle|>"
+              && vocab.id_to_token[107].text == "") {
+                vocab.special_prefix_id = 67;
+                vocab.special_suffix_id = 69;
+                vocab.special_middle_id = 68;
+                // TODO: this is not EOT, it is "file separator" token, needs fix
+                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
+                //vocab.special_eot_id    = 70;
+                vocab.special_eot_id    = 107;
+            }
+        }
         try {
-            vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
+            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
         } catch (const std::exception & e) {
             LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
             vocab.linefeed_id = vocab.special_pad_id;
@@ -4831,10 +5626,10 @@ static void llm_load_vocab(
             bool temp = true;
 
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.special_add_bos = int(temp);
+                vocab.tokenizer_add_bos = temp;
             }
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.special_add_eos = int(temp);
+                vocab.tokenizer_add_eos = temp;
             }
         }
 
@@ -4865,12 +5660,12 @@ static void llm_load_vocab(
     // build special tokens cache
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
+            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
                 vocab.cache_special_tokens.push_back(id);
             }
         }
 
-        std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
+        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
             [&] (const llama_vocab::id a, const llama_vocab::id b) {
                 return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
@@ -4934,7 +5729,7 @@ static void llm_load_vocab(
         );
 
         // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
             _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : vocab.cache_special_tokens) {
@@ -4956,43 +5751,78 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
 
     const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
 
+    auto print_f = [](const std::function & f, uint32_t n) {
+        bool is_var = false;
+
+        std::vector v;
+        for (uint32_t i = 0; i < n; ++i) {
+            v.push_back(f(i));
+            if (v[i] != v[0]) {
+                is_var = true;
+            }
+        }
+
+        std::stringstream ss;
+
+        if (is_var) {
+            ss << "[";
+            for (uint32_t i = 0; i < n; ++i) {
+                ss << v[i];
+                if (i < n - 1) {
+                    ss << ", ";
+                }
+            }
+            ss << "]";
+        } else {
+            ss << v[0];
+        }
+
+        return ss.str();
+    };
+
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch));
     LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
     LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
-    LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
-    LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
-    LLAMA_LOG_INFO("%s: n_head           = %u\n",     __func__, hparams.n_head);
-    LLAMA_LOG_INFO("%s: n_head_kv        = %u\n",     __func__, hparams.n_head_kv);
-    LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
-    LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
-    LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
-    LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
-    LLAMA_LOG_INFO("%s: n_gqa            = %u\n",     __func__, hparams.n_gqa());
-    LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %u\n",     __func__, hparams.n_embd_k_gqa());
-    LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %u\n",     __func__, hparams.n_embd_v_gqa());
-    LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
-    LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
-    LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
-    LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
-    LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
-    LLAMA_LOG_INFO("%s: n_ff             = %u\n",     __func__, hparams.n_ff);
-    LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
-    LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
-    LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
-    LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
-    LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
-    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
-    LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
-    LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
-    LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
-    LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
-    LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
-    LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
-    LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
-    LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+    LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
+
+    if (!hparams.vocab_only) {
+        LLAMA_LOG_INFO("%s: n_ctx_train      = %u\n",     __func__, hparams.n_ctx_train);
+        LLAMA_LOG_INFO("%s: n_embd           = %u\n",     __func__, hparams.n_embd);
+        LLAMA_LOG_INFO("%s: n_layer          = %u\n",     __func__, hparams.n_layer);
+        LLAMA_LOG_INFO("%s: n_head           = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head(il);    }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_head_kv        = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_rot            = %u\n",     __func__, hparams.n_rot);
+        LLAMA_LOG_INFO("%s: n_swa            = %u\n",     __func__, hparams.n_swa);
+        LLAMA_LOG_INFO("%s: n_embd_head_k    = %u\n",     __func__, hparams.n_embd_head_k);
+        LLAMA_LOG_INFO("%s: n_embd_head_v    = %u\n",     __func__, hparams.n_embd_head_v);
+        LLAMA_LOG_INFO("%s: n_gqa            = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il);        }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_k_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_embd_v_gqa     = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: f_norm_eps       = %.1e\n",   __func__, hparams.f_norm_eps);
+        LLAMA_LOG_INFO("%s: f_norm_rms_eps   = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+        LLAMA_LOG_INFO("%s: f_clamp_kqv      = %.1e\n",   __func__, hparams.f_clamp_kqv);
+        LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n",   __func__, hparams.f_max_alibi_bias);
+        LLAMA_LOG_INFO("%s: f_logit_scale    = %.1e\n",   __func__, hparams.f_logit_scale);
+        LLAMA_LOG_INFO("%s: n_ff             = %s\n",     __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
+        LLAMA_LOG_INFO("%s: n_expert         = %u\n",     __func__, hparams.n_expert);
+        LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",     __func__, hparams.n_expert_used);
+        LLAMA_LOG_INFO("%s: causal attn      = %d\n",     __func__, hparams.causal_attn);
+        LLAMA_LOG_INFO("%s: pooling type     = %d\n",     __func__, hparams.pooling_type);
+        LLAMA_LOG_INFO("%s: rope type        = %d\n",     __func__, hparams.rope_type);
+        LLAMA_LOG_INFO("%s: rope scaling     = %s\n",     __func__, rope_scaling_type);
+        LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n",   __func__, hparams.rope_freq_base_train);
+        LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",     __func__, hparams.rope_freq_scale_train);
+        LLAMA_LOG_INFO("%s: n_ctx_orig_yarn  = %u\n",     __func__, hparams.n_ctx_orig_yarn);
+        LLAMA_LOG_INFO("%s: rope_finetuned   = %s\n",     __func__, hparams.rope_finetuned ? "yes" : "unknown");
+        LLAMA_LOG_INFO("%s: ssm_d_conv       = %u\n",     __func__, hparams.ssm_d_conv);
+        LLAMA_LOG_INFO("%s: ssm_d_inner      = %u\n",     __func__, hparams.ssm_d_inner);
+        LLAMA_LOG_INFO("%s: ssm_d_state      = %u\n",     __func__, hparams.ssm_d_state);
+        LLAMA_LOG_INFO("%s: ssm_dt_rank      = %u\n",     __func__, hparams.ssm_dt_rank);
+    }
+
     LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     if (ml.n_elements >= 1e12) {
@@ -5028,6 +5858,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
 
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
@@ -5037,6 +5869,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
+
+    if (model.arch == LLM_ARCH_QWEN2MOE) {
+        LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
+    }
 }
 
 // Returns false if cancelled by progress_callback
@@ -5054,19 +5891,12 @@ static bool llm_load_tensors(
 
     auto & hparams = model.hparams;
 
-#ifdef GGML_USE_SYCL
-    // disable MoE with SYCL until mul_mat_id is updated
-    if (hparams.n_expert > 0) {
-        n_gpu_layers = 0;
-    }
-#endif
-
     model.split_mode   = split_mode;
     model.main_gpu     = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
 
-    const int64_t n_layer     = hparams.n_layer;
-    const int64_t i_gpu_start = std::max((int64_t) hparams.n_layer - n_gpu_layers, (int64_t) 0);
+    const int n_layer     = hparams.n_layer;
+    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
     bool use_mmap_buffer = true;
 
     // there is very little benefit to offloading the input layer, so always keep it on the CPU
@@ -5076,7 +5906,7 @@ static bool llm_load_tensors(
     model.buft_layer.resize(n_layer);
 
     // assign cpu layers
-    for (int64_t i = 0; i < i_gpu_start; ++i) {
+    for (int i = 0; i < i_gpu_start; ++i) {
         model.buft_layer[i] = llama_default_buffer_type_cpu(true);
     }
 
@@ -5106,7 +5936,7 @@ static bool llm_load_tensors(
 
         // assign the repeating layers to the devices according to the splits
         int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
-        for (int64_t i = i_gpu_start; i < n_layer; ++i) {
+        for (int i = i_gpu_start; i < n_layer; ++i) {
             int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
             model.buft_layer[i] = llama_default_buffer_type_offload(model, layer_gpu);
         }
@@ -5126,7 +5956,7 @@ static bool llm_load_tensors(
             split_buft = llama_default_buffer_type_offload(model, main_gpu);
         }
         // assign the repeating layers
-        for (int64_t i = i_gpu_start; i < n_layer; ++i) {
+        for (int i = i_gpu_start; i < n_layer; ++i) {
             model.buft_layer[i] = {
                 split_buft,
                 llama_default_buffer_type_offload(model, main_gpu)
@@ -5149,7 +5979,7 @@ static bool llm_load_tensors(
     buft_layer_count[model.buft_input.buft_matrix]++;
     buft_layer_count[model.buft_output.buft]++;
     buft_layer_count[model.buft_output.buft_matrix]++;
-    for (int64_t i = 0; i < n_layer; ++i) {
+    for (int i = 0; i < n_layer; ++i) {
         buft_layer_count[model.buft_layer[i].buft]++;
         buft_layer_count[model.buft_layer[i].buft_matrix]++;
     }
@@ -5179,15 +6009,21 @@ static bool llm_load_tensors(
 
     // create tensors for the weights
     {
-        const int64_t n_embd       = hparams.n_embd;
-        const int64_t n_embd_head  = n_embd / hparams.n_head;
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
-        const int64_t n_embd_gqa   = n_embd_v_gqa;
-        const int64_t n_vocab      = hparams.n_vocab;
-        const int64_t n_vocab_type = hparams.n_vocab_type;
-        const int64_t n_ff         = hparams.n_ff;
-        const int64_t n_expert     = hparams.n_expert;
+        // note: cast to int64_t since we will use these for the tensor dimensions
+        const int64_t n_head        = hparams.n_head();
+        const int64_t n_head_kv     = hparams.n_head_kv();
+        const int64_t n_embd        = hparams.n_embd;
+        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+        const int64_t n_embd_head_v = hparams.n_embd_head_v;
+        const int64_t n_ff          = hparams.n_ff();
+        const int64_t n_embd_gqa    = n_embd_v_gqa;
+        const int64_t n_vocab       = hparams.n_vocab;
+        const int64_t n_vocab_type  = hparams.n_vocab_type;
+        const int64_t n_expert      = hparams.n_expert;
+        const int64_t n_expert_used = hparams.n_expert_used;
+        const int64_t n_ctx_train   = hparams.n_ctx_train;
 
         if (n_expert > 0 && hparams.n_expert_used == 0) {
             throw std::runtime_error("model has expert layers but no expert layers are used");
@@ -5196,8 +6032,9 @@ static bool llm_load_tensors(
         ggml_context * ctx_input        = ctx_map.at(model.buft_input.buft);
         ggml_context * ctx_output       = ctx_map.at(model.buft_output.buft);
         ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
-        auto ctx_for_layer              = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
-        auto ctx_for_layer_split        = [&](int i) { return ctx_map.at(model.buft_layer[i].buft_matrix); };
+
+        auto ctx_for_layer       = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); };
+        auto ctx_for_layer_split = [&](int i) { return ctx_map.at(model.buft_layer[i].buft_matrix); };
 
         model.layers.resize(n_layer);
 
@@ -5212,7 +6049,8 @@ static bool llm_load_tensors(
                     // output
                     {
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         // if output is NULL, init from the input tok embed
                         if (model.output == NULL) {
                             model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
@@ -5227,10 +6065,10 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
                         // optional bias tensors
                         layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -5240,6 +6078,8 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
+                        layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+
                         if (n_expert == 0) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
@@ -5295,6 +6135,7 @@ static bool llm_load_tensors(
                     {
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                         model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         // if output is NULL, init from the input tok embed
                         if (model.output == NULL) {
                             model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
@@ -5318,9 +6159,9 @@ static bool llm_load_tensors(
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
-
+                        layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                         layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         if (layer.ffn_gate_exps) {
                             layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
                             layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
@@ -5372,12 +6213,12 @@ static bool llm_load_tensors(
 
                     auto & layer = model.layers[i];
 
-                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,  "weight", i), {n_embd});
+                    layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
                     layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                     layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
-                    layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
+                    layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
 
                     layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                     layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert});
@@ -5419,10 +6260,10 @@ static bool llm_load_tensors(
 
                     // output
                     {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
 
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         if (!model.output) {
                             model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
                         }
@@ -5450,7 +6291,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_STARCODER:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, hparams.n_ctx_train});
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
 
                     // output
                     {
@@ -5476,8 +6317,8 @@ static bool llm_load_tensors(
                         layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                         layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
 
-                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
 
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
@@ -5485,8 +6326,8 @@ static bool llm_load_tensors(
                         layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
                         layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff});
+                        layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff});
                     }
                 } break;
             case LLM_ARCH_BERT:
@@ -5494,8 +6335,9 @@ static bool llm_load_tensors(
                 {
                     model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab});
                     model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type});
+
                     if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, hparams.n_ctx_train});
+                        model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train});
                     }
 
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd});
@@ -5508,31 +6350,30 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];
 
                         if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                            layer.bq   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
+                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                            layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
 
-                            layer.wk   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bk   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
+                            layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                            layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
 
-                            layer.wv   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                            layer.bv   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+                            layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                            layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
                         } else {
                             layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
                         }
 
-                        layer.wo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd});
 
                         layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd});
                         layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
 
-                        layer.ffn_up          = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
-                        layer.ffn_down        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd});
 
                         if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
-                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
-
-                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                            layer.bo         = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
+                            layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff});
+                            layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
                         } else {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
                         }
@@ -5543,8 +6384,9 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_JINA_BERT_V2:
                 {
-                    model.tok_embd     = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
-                    model.type_embd    = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); //token_type_embeddings
+                    model.tok_embd  = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}); // word_embeddings
+                    model.type_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}); // token_type_embeddings
+
                     model.tok_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}); // LayerNorm
                     model.tok_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}); //LayerNorm bias
 
@@ -5554,38 +6396,38 @@ static bool llm_load_tensors(
 
                         auto & layer = model.layers[i]; // JinaBertLayer
 
-                        layer.wq   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
-                        layer.bq   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
 
                         layer.attn_q_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wk   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.bk   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa});
 
                         layer.attn_k_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.wv   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
-                        layer.bv   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa});
 
-                        layer.wo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}); //output_dens
-                        layer.bo              = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "bias", i), {n_embd}); //output_dens
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens
+                        layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}); //output_dens
 
                         layer.attn_out_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
-                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd});
+                        layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd});
 
                         layer.attn_norm_2   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff});
-                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,    "weight", i), {n_embd, n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
 
-                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,        "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN,      "bias", i), {n_embd});
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
 
-                        layer.layer_out_norm = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM,        "weight", i), {n_embd});
-                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM,        "bias", i), {n_embd});
+                        layer.layer_out_norm   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                        layer.layer_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd});
                     }
                 } break;
             case LLM_ARCH_BLOOM:
@@ -5608,35 +6450,35 @@ static bool llm_load_tensors(
                         auto & layer = model.layers[i];
 
                         layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
-                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd});
 
                         layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa});
 
                         layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
-                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd});
 
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
-                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd});
 
                         layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
-                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd});
 
-                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
-                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff});
                     }
                 } break;
             case LLM_ARCH_MPT:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, hparams.n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                     // output
                     {
-                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output_norm   = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
-                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         if (!model.output) {
                             model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
                         }
@@ -5707,8 +6549,8 @@ static bool llm_load_tensors(
                         layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -5819,20 +6661,23 @@ static bool llm_load_tensors(
 
                         layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
 
-                        GGML_ASSERT(hparams.n_expert      > 0);
-                        GGML_ASSERT(hparams.n_expert_used > 0);
+                        GGML_ASSERT(n_expert      > 0);
+                        GGML_ASSERT(n_expert_used > 0);
 
                         // MoE branch
-                        auto n_ff_exp = n_ff / hparams.n_expert_used;
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
                         layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
                         layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
                         layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
 
                         // Shared expert branch
+                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
+
                         layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
-                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp});
+                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd});
+                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp});
                     }
                 } break;
             case LLM_ARCH_PHI2:
@@ -5882,6 +6727,8 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_PHI3:
                 {
+                    const int64_t n_embd_head = n_embd / n_head;
+
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
 
                     // output
@@ -5891,8 +6738,8 @@ static bool llm_load_tensors(
                     }
 
                     for (int i = 0; i < n_layer; ++i) {
-                        ggml_context* ctx_layer = ctx_for_layer(i);
-                        ggml_context* ctx_split = ctx_for_layer_split(i);
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
 
                         auto & layer = model.layers[i];
 
@@ -5941,7 +6788,7 @@ static bool llm_load_tensors(
             case LLM_ARCH_GPT2:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"),   {n_embd, hparams.n_ctx_train});
+                    model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train});
 
                     // output
                     {
@@ -6078,12 +6925,7 @@ static bool llm_load_tensors(
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                     model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
 
-                    const int64_t n_ff          = hparams.n_ff;
-                    const int64_t n_embd_head_k = hparams.n_embd_head_k;
-                    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-                    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-
-                    for (uint32_t i = 0; i < n_layer; ++i) {
+                    for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
 
@@ -6091,10 +6933,10 @@ static bool llm_load_tensors(
 
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
-                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * hparams.n_head});
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * hparams.n_head, n_embd});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
@@ -6102,6 +6944,35 @@ static bool llm_load_tensors(
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                     }
                 } break;
+            case LLM_ARCH_GEMMA2:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});
+                        layer.attn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_post_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd});
+                    }
+                } break;
             case LLM_ARCH_STARCODER2:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -6156,6 +7027,7 @@ static bool llm_load_tensors(
                     const int64_t d_inner = hparams.ssm_d_inner;
                     const int64_t d_state = hparams.ssm_d_state;
                     const int64_t dt_rank = hparams.ssm_dt_rank;
+
                     // only an expansion factor of 2 is supported for now
                     GGML_ASSERT(2 * n_embd == d_inner);
 
@@ -6206,15 +7078,20 @@ static bool llm_load_tensors(
                         model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                         model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
                     }
+
                     for (int i = 0; i < n_layer; ++i) {
                         ggml_context * ctx_layer = ctx_for_layer(i);
                         ggml_context * ctx_split = ctx_for_layer_split(i);
+
                         auto & layer = model.layers[i];
+
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
@@ -6241,8 +7118,8 @@ static bool llm_load_tensors(
                         layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
 
                         if (n_layer >= 64){
-                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head});
-                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv});
+                            layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head});
+                            layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv});
                         }
 
                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
@@ -6278,15 +7155,49 @@ static bool llm_load_tensors(
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
-
                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        // init output from the input tok embed
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head      =   hparams.n_head(i);
+                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
+                        const int64_t n_ff        =   hparams.n_ff(i);
+
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k});
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k});
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                    }
+                } break;
             case LLM_ARCH_GPTNEOX:
                 {
-                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
                     // output
                     {
                         model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
@@ -6325,8 +7236,9 @@ static bool llm_load_tensors(
 
                     // output
                     {
-                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
                         // if output is NULL, init from the input tok embed
                         if (model.output == NULL) {
                             model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
@@ -6361,13 +7273,16 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_DEEPSEEK2:
                 {
-                    bool is_lite = (hparams.n_layer == 27);
+                    const bool is_lite = (hparams.n_layer == 27);
 
-                    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
-                    const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-                    const uint32_t q_lora_rank = hparams.n_lora_q;
-                    const uint32_t kv_lora_rank = hparams.n_lora_kv;
-                    const uint32_t n_ff_exp = hparams.n_ff_exp;
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
 
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
 
@@ -6387,29 +7302,31 @@ static bool llm_load_tensors(
                         if (!is_lite) {
                             layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
                         }
+
                         layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
 
                         if (!is_lite) {
-                            layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A,   "weight", i), {n_embd, q_lora_rank});
-                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B,   "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k});
+                            layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
+                            layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
                         } else {
-                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                            layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
                         }
-                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA,   "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope});
-                        layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,   "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
-                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd});
+
+                        layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
+                        layer.wkv_b     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
+                        layer.wo        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd});
 
                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
 
-                        if ((uint32_t) i < hparams.n_layer_dense_lead) {
+                        if (i < (int) hparams.n_layer_dense_lead) {
                             layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff});
                             layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd});
                             layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                         } else {
                             layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
 
-                            GGML_ASSERT(hparams.n_expert      > 0);
-                            GGML_ASSERT(hparams.n_expert_used > 0);
+                            GGML_ASSERT(n_expert      > 0);
+                            GGML_ASSERT(n_expert_used > 0);
 
                             // MoE branch
                             layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
@@ -6417,12 +7334,179 @@ static bool llm_load_tensors(
                             layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
 
                             // Shared expert branch
-                            layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,   n_ff_exp * hparams.n_expert_shared});
-                            layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff_exp * hparams.n_expert_shared, n_embd});
-                            layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,   n_ff_exp * hparams.n_expert_shared});
+                            layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared});
+                            layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd});
+                            layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared});
                         }
                     }
                 } break;
+            case LLM_ARCH_BITNET:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd});
+                        layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
+
+                        layer.wq       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
+                        layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1});
+                        layer.wk       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1});
+                        layer.wv       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1});
+                        layer.wo       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1});
+
+                        layer.ffn_norm     = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd});
+                        layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
+
+                        layer.ffn_gate       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1});
+                        layer.ffn_down       = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1});
+                        layer.ffn_up         = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_scale   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1});
+                    }
+                } break;
+            case LLM_ARCH_T5:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm_enc = ml.create_tensor(ctx_output, tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm     = ml.create_tensor(ctx_output, tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd});
+
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        // if output is NULL, init from the input tok embed
+                        if (model.output == NULL) {
+                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm_enc  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_rel_b_enc = ml.create_tensor(ctx_input, tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.ffn_norm_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate_enc = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up_enc   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+
+                        layer.attn_norm  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd});
+                        layer.attn_rel_b = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.attn_norm_cross  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd});
+                        // this tensor seems to be unused in HF transformers implementation
+                        layer.attn_rel_b_cross = ml.create_tensor(ctx_input, tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.wq_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wk_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa});
+                        layer.wv_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa});
+                        layer.wo_cross = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd});
+
+                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_gate = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd});
+                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff});
+                    }
+                } break;
+            case LLM_ARCH_JAIS:
+                {
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                    // Output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                        layer.ffn_gate   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                    }
+                } break;
+            case LLM_ARCH_CHATGLM:
+                {
+                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + (hparams.n_embd_head_k << 2)});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -6622,16 +7706,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         }
 #endif
 
-#ifdef GGML_USE_SYCL
-        if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
-            ggml_backend_sycl_set_single_device_mode(params.main_gpu);
-            //SYCL use device index (0, 1, 2) directly, uer input device id, then convert to device index.
-            params.main_gpu = ggml_backend_sycl_get_device_index(params.main_gpu);
-        } else {
-            ggml_backend_sycl_set_mul_device_mode();
-        }
-#endif
-
         if (!llm_load_tensors(
             ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
             params.progress_callback, params.progress_callback_user_data
@@ -6657,6 +7731,7 @@ enum llm_ffn_op_type {
     LLM_FFN_GELU,
     LLM_FFN_RELU,
     LLM_FFN_RELU_SQR,
+    LLM_FFN_SWIGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -6711,8 +7786,8 @@ static void llm_build_kv_store(
                     int64_t   il) {
     const int64_t n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
     GGML_ASSERT(kv.size == n_ctx);
 
@@ -6743,6 +7818,58 @@ static void llm_build_kv_store(
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
 }
 
+// do mat_mul, while optionally apply lora
+static struct ggml_tensor * llm_build_lora_mm(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,
+          struct ggml_tensor * cur) {
+    struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct ggml_tensor * ab_cur = ggml_mul_mat(
+            ctx0, lora->b,
+            ggml_mul_mat(ctx0, lora->a, cur)
+        );
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
+// do mat_mul_id, while optionally apply lora
+static struct ggml_tensor * llm_build_lora_mm_id(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,   // struct ggml_tensor * as
+          struct ggml_tensor * cur, // struct ggml_tensor * b
+          struct ggml_tensor * ids) {
+    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        if (lora == nullptr) {
+            continue;
+        }
+        const float alpha = it.first->alpha;
+        const float rank  = (float) lora->b->ne[0];
+        const float scale = alpha ? it.second * alpha / rank : it.second;
+        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
+            ctx0, lora->b,
+            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ids
+        );
+        ab_cur = ggml_scale(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
 static struct ggml_tensor * llm_build_norm(
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
@@ -6777,19 +7904,23 @@ static struct ggml_tensor * llm_build_norm(
 
 static struct ggml_tensor * llm_build_ffn(
         struct ggml_context * ctx,
+       struct llama_context & lctx,
          struct ggml_tensor * cur,
          struct ggml_tensor * up,
          struct ggml_tensor * up_b,
+         struct ggml_tensor * up_s,
          struct ggml_tensor * gate,
          struct ggml_tensor * gate_b,
+         struct ggml_tensor * gate_s,
          struct ggml_tensor * down,
          struct ggml_tensor * down_b,
+         struct ggml_tensor * down_s,
          struct ggml_tensor * act_scales,
             llm_ffn_op_type   type_op,
           llm_ffn_gate_type   type_gate,
          const llm_build_cb & cb,
                         int   il) {
-    struct ggml_tensor * tmp = up ? ggml_mul_mat(ctx, up, cur) : cur;
+    struct ggml_tensor * tmp = up ? llm_build_lora_mm(lctx, ctx, up, cur) : cur;
     cb(tmp, "ffn_up", il);
 
     if (up_b) {
@@ -6797,16 +7928,21 @@ static struct ggml_tensor * llm_build_ffn(
         cb(tmp, "ffn_up_b", il);
     }
 
+    if (up_s) {
+        tmp = ggml_mul(ctx, tmp, up_s);
+        cb(tmp, "ffn_up_s", il);
+    }
+
     if (gate) {
         switch (type_gate) {
             case LLM_FFN_SEQ:
                 {
-                    cur = ggml_mul_mat(ctx, gate, tmp);
+                    cur = llm_build_lora_mm(lctx, ctx, gate, tmp);
                     cb(cur, "ffn_gate", il);
                 } break;
             case LLM_FFN_PAR:
                 {
-                    cur = ggml_mul_mat(ctx, gate, cur);
+                    cur = llm_build_lora_mm(lctx, ctx, gate, cur);
                     cb(cur, "ffn_gate", il);
                 } break;
         }
@@ -6815,6 +7951,12 @@ static struct ggml_tensor * llm_build_ffn(
             cur = ggml_add(ctx, cur, gate_b);
             cb(cur, "ffn_gate_b", il);
         }
+
+        if (gate_s) {
+            cur = ggml_mul(ctx, cur, gate_s);
+            cb(cur, "ffn_gate_s", il);
+        }
+
     } else {
         cur = tmp;
     }
@@ -6847,6 +7989,19 @@ static struct ggml_tensor * llm_build_ffn(
                 cur = ggml_sqr(ctx, cur);
                 cb(cur, "ffn_sqr(relu)", il);
             } break;
+        case LLM_FFN_SWIGLU:
+            {
+                // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
+                int64_t split_point = cur->ne[0] / 2;
+                struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                x0 = ggml_silu(ctx, x0);
+                cb(cur, "ffn_silu", il);
+
+                cur = ggml_mul(ctx, x0, x1);
+                cb(cur, "ffn_mul", il);
+            } break;
     }
 
     if (type_gate == LLM_FFN_PAR) {
@@ -6854,7 +8009,10 @@ static struct ggml_tensor * llm_build_ffn(
         cb(cur, "ffn_gate_par", il);
     }
 
-    cur = ggml_mul_mat(ctx, down, cur);
+    if (down) {
+        cur = llm_build_lora_mm(lctx, ctx, down, cur);
+    }
+
     if (down_b) {
         cb(cur, "ffn_down", il);
     }
@@ -6863,11 +8021,17 @@ static struct ggml_tensor * llm_build_ffn(
         cur = ggml_add(ctx, cur, down_b);
     }
 
+    if (down_s) {
+        cur = ggml_mul(ctx, cur, down_s);
+        cb(cur, "ffn_down_s", il);
+    }
+
     return cur;
 }
 
 static struct ggml_tensor * llm_build_moe_ffn(
         struct ggml_context * ctx,
+       struct llama_context & lctx,
          struct ggml_tensor * cur,
          struct ggml_tensor * gate_inp,
          struct ggml_tensor * up_exps,
@@ -6884,7 +8048,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
     int64_t n_embd = cur->ne[0];
     int64_t n_tokens = cur->ne[1];
 
-    ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens]
+    ggml_tensor * logits = llm_build_lora_mm(lctx, ctx, gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);
 
     ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
@@ -6916,10 +8080,10 @@ static struct ggml_tensor * llm_build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);
 
     switch (type_op) {
@@ -6934,13 +8098,13 @@ static struct ggml_tensor * llm_build_moe_ffn(
                 cb(gate, "ffn_moe_gelu", il);
             } break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
     }
 
     ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
     cb(par, "ffn_moe_gate_par", il);
 
-    ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     experts = ggml_mul(ctx, experts, weights);
@@ -6968,9 +8132,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
 
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
-          const llama_model & model,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
+       struct llama_context & lctx,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
@@ -6982,13 +8144,17 @@ static struct ggml_tensor * llm_build_kqv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const llama_model   & model   = lctx.model;
+    const llama_hparams & hparams = lctx.model.hparams;
+    const llama_cparams & cparams = lctx.cparams;
+
     const int64_t n_ctx         = cparams.n_ctx;
-    const int64_t n_head        = hparams.n_head;
-    const int64_t n_head_kv     = hparams.n_head_kv;
+    const int64_t n_head        = hparams.n_head(il);
+    const int64_t n_head_kv     = hparams.n_head_kv(il);
     const int64_t n_embd_head_k = hparams.n_embd_head_k;
-    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
+    const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_head_v = hparams.n_embd_head_v;
-    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
+    const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(il);
 
     struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
     cb(q, "q", il);
@@ -7047,6 +8213,12 @@ static struct ggml_tensor * llm_build_kqv(
             kq = ggml_scale(ctx, kq, 30);
         }
 
+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh(ctx, kq);
+            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        }
+
         kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
 
@@ -7073,7 +8245,10 @@ static struct ggml_tensor * llm_build_kqv(
 
     ggml_build_forward_expand(graph, cur);
 
-    cur = ggml_mul_mat(ctx, wo, cur);
+    if (wo) {
+        cur = llm_build_lora_mm(lctx, ctx, wo, cur);
+    }
+
     if (wo_b) {
         cb(cur, "kqv_wo", il);
     }
@@ -7087,9 +8262,7 @@ static struct ggml_tensor * llm_build_kqv(
 
 static struct ggml_tensor * llm_build_kv(
         struct ggml_context * ctx,
-          const llama_model & model,
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
+       struct llama_context & lctx,
        const llama_kv_cache & kv,
          struct ggml_cgraph * graph,
          struct ggml_tensor * wo,
@@ -7104,6 +8277,8 @@ static struct ggml_tensor * llm_build_kv(
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
+    const llama_hparams & hparams = lctx.model.hparams;
+    const llama_cparams & cparams = lctx.cparams;
 
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
@@ -7115,7 +8290,7 @@ static struct ggml_tensor * llm_build_kv(
 
     struct ggml_tensor * cur;
 
-    cur  = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
+    cur  = llm_build_kqv(ctx, lctx, kv, graph, wo, wo_b,
             q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
@@ -7155,6 +8330,7 @@ struct llm_build_context {
     const int32_t n_tokens;
     const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
     const int32_t n_outputs;
+    const int32_t n_outputs_enc;
     const int32_t kv_head;  // index of where we store new KV data in the cache
     const int32_t n_ctx_orig;
 
@@ -7185,8 +8361,8 @@ struct llm_build_context {
         n_layer          (hparams.n_layer),
         n_rot            (hparams.n_rot),
         n_ctx            (cparams.n_ctx),
-        n_head           (hparams.n_head),
-        n_head_kv        (hparams.n_head_kv),
+        n_head           (hparams.n_head()),
+        n_head_kv        (hparams.n_head_kv()),
         n_embd_head_k    (hparams.n_embd_head_k),
         n_embd_k_gqa     (hparams.n_embd_k_gqa()),
         n_embd_head_v    (hparams.n_embd_head_v),
@@ -7204,6 +8380,7 @@ struct llm_build_context {
         n_tokens         (batch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
+        n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head          (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
         n_ctx_orig       (cparams.n_ctx_orig_yarn),
         flash_attn       (cparams.flash_attn),
@@ -7223,17 +8400,21 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens  = nullptr;
-        lctx.inp_embd    = nullptr;
-        lctx.inp_pos     = nullptr;
-        lctx.inp_out_ids = nullptr;
-        lctx.inp_KQ_mask = nullptr;
-        lctx.inp_K_shift = nullptr;
-        lctx.inp_mean    = nullptr;
-        lctx.inp_cls     = nullptr;
-        lctx.inp_s_copy  = nullptr;
-        lctx.inp_s_mask  = nullptr;
-        lctx.inp_s_seq   = nullptr;
+        lctx.inp_tokens      = nullptr;
+        lctx.inp_embd        = nullptr;
+        lctx.inp_pos         = nullptr;
+        lctx.inp_out_ids     = nullptr;
+        lctx.inp_KQ_mask     = nullptr;
+        lctx.inp_KQ_mask_swa = nullptr;
+        lctx.inp_K_shift     = nullptr;
+        lctx.inp_mean        = nullptr;
+        lctx.inp_cls         = nullptr;
+        lctx.inp_s_copy      = nullptr;
+        lctx.inp_s_mask      = nullptr;
+        lctx.inp_s_seq       = nullptr;
+        lctx.inp_pos_bucket    = nullptr;
+        lctx.inp_embd_enc      = nullptr;
+        lctx.inp_KQ_mask_cross = nullptr;
     }
 
     void free() {
@@ -7244,7 +8425,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -7252,8 +8433,9 @@ struct llm_build_context {
         cb(lctx.inp_K_shift, "K_shift", -1);
         ggml_set_input(lctx.inp_K_shift);
 
-
         for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
             struct ggml_tensor * rope_factors = build_rope_factors(il);
             struct ggml_tensor * tmp =
                 // we rotate only the first n_rot dimensions
@@ -7274,7 +8456,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_s_copy() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         GGML_ASSERT(kv_self.recurrent);
 
@@ -7297,7 +8479,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         for (uint32_t i = 0; i < ids.size(); ++i) {
             const uint32_t id = ids[i];
@@ -7313,6 +8495,9 @@ struct llm_build_context {
             }
 
             for (int il = 0; il < n_layer; ++il) {
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
                 ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
                         n_embd_k_gqa, nm,
                         ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
@@ -7372,6 +8557,10 @@ struct llm_build_context {
         // choose long/short freq factors based on the context size
         const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
 
+        if (model.layers[il].rope_freqs != nullptr) {
+            return model.layers[il].rope_freqs;
+        }
+
         if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
             return model.layers[il].rope_long;
         }
@@ -7387,16 +8576,27 @@ struct llm_build_context {
     }
 
     struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        }
+        lctx.inp_KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         cb(lctx.inp_KQ_mask, "KQ_mask", -1);
         ggml_set_input(lctx.inp_KQ_mask);
+
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
+    struct ggml_tensor * build_inp_KQ_mask_swa(bool causal = true) {
+        GGML_ASSERT(hparams.n_swa > 0);
+
+        lctx.inp_KQ_mask_swa = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(lctx.inp_KQ_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(lctx.inp_KQ_mask_swa);
+
+        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask_swa, GGML_TYPE_F16) : lctx.inp_KQ_mask_swa;
+    }
+
     struct ggml_tensor * build_inp_mean() {
         lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
@@ -7432,8 +8632,99 @@ struct llm_build_context {
         return lctx.inp_s_seq;
     }
 
+    struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) {
+        // find result_norm tensor for input
+        struct ggml_tensor * inp = nullptr;
+        for (int i = gf->n_nodes - 1; i >= 0; --i) {
+            inp = gf->nodes[i];
+            if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) {
+                break;
+            } else {
+                inp = nullptr;
+            }
+        }
+        GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor");
+
+        struct ggml_tensor * cur;
+
+        switch (pooling_type) {
+            case LLAMA_POOLING_TYPE_MEAN:
+                {
+                    struct ggml_tensor * inp_mean = build_inp_mean();
+                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
+                } break;
+            case LLAMA_POOLING_TYPE_CLS:
+            case LLAMA_POOLING_TYPE_LAST:
+                {
+                    struct ggml_tensor * inp_cls = build_inp_cls();
+                    cur = ggml_get_rows(ctx0, inp, inp_cls);
+                } break;
+            case LLAMA_POOLING_TYPE_NONE:
+                {
+                    cur = inp;
+                } break;
+            default:
+                {
+                    GGML_ABORT("unknown pooling type");
+                }
+        }
+
+        cb(cur, "result_embd_pooled", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_tensor * llm_build_pos_bucket(bool causal) {
+        if (causal) {
+            lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv,     n_tokens);
+        } else {
+            lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
+        }
+
+        ggml_set_input(lctx.inp_pos_bucket);
+        cb(lctx.inp_pos_bucket, "pos_bucket", -1);
+
+        return lctx.inp_pos_bucket;
+    }
+
+    struct ggml_tensor * llm_build_pos_bias(struct ggml_tensor * pos_bucket, struct ggml_tensor * attn_rel_b) {
+        struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
+        cb(pos_bucket_1d, "pos_bucket_1d", -1);
+
+        struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
+        cb(pos_bias, "pos_bias", -1);
+
+        pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0],  0);
+        cb(pos_bias, "pos_bias", -1);
+
+        pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
+        cb(pos_bias, "pos_bias", -1);
+
+        pos_bias = ggml_cont(ctx0, pos_bias);
+        cb(pos_bias, "pos_bias", -1);
+
+        return pos_bias;
+    }
+
+    struct ggml_tensor * llm_build_inp_embd_enc() {
+        const int64_t n_embd = hparams.n_embd;
+        lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
+        ggml_set_input(lctx.inp_embd_enc);
+        cb(lctx.inp_embd_enc, "embd_enc", -1);
+        return lctx.inp_embd_enc;
+    }
+
+    struct ggml_tensor * llm_build_inp_KQ_mask_cross() {
+        lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        ggml_set_input(lctx.inp_KQ_mask_cross);
+        cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1);
+        return lctx.inp_KQ_mask_cross;
+    }
+
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -7464,22 +8755,25 @@ struct llm_build_context {
 
             // self-attention
             {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                struct ggml_tensor * rope_factors = build_rope_factors(il);
+
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -7487,20 +8781,20 @@ struct llm_build_context {
                 }
 
                 Qcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Qcur, "Qcur", il);
 
                 Kcur = ggml_rope_ext(
-                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7523,10 +8817,10 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
@@ -7537,7 +8831,7 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_moe_ffn(ctx0, cur,
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
@@ -7552,10 +8846,7 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7570,7 +8861,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -7579,7 +8870,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7606,13 +8897,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
@@ -7633,12 +8924,12 @@ struct llm_build_context {
                         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
                     default:
-                        GGML_ASSERT(false);
+                        GGML_ABORT("fatal error");
                 }
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7660,16 +8951,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7684,7 +8976,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -7693,7 +8985,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7720,13 +9012,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -7742,7 +9034,7 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(Kcur, "Kcur", il);
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7764,16 +9056,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7786,7 +9079,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -7795,7 +9088,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -7835,7 +9128,7 @@ struct llm_build_context {
                     cur = attn_norm;
                 }
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
@@ -7862,7 +9155,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -7879,19 +9172,18 @@ struct llm_build_context {
 
             // feed forward
             {
-                cur = llm_build_ffn(ctx0, attn_norm, // !! use the attn norm, not the result
-                        model.layers[il].ffn_up,   NULL,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
-            cb(cur, "l_out", il);
-
             cur = ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -7907,7 +9199,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -7916,7 +9208,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -7952,21 +9244,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -7987,7 +9279,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -8019,7 +9311,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -8043,10 +9335,7 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8061,7 +9350,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         // Grok
         // multiply logits by output_multiplier_scale of 0.5773502691896257
@@ -8076,7 +9365,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8112,7 +9401,7 @@ struct llm_build_context {
                 struct ggml_tensor * Kcur = nullptr;
                 struct ggml_tensor * Vcur = nullptr;
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -8140,7 +9429,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8163,7 +9452,7 @@ struct llm_build_context {
                                  LLM_NORM, cb, il);
             cb(cur, "attn_out_norm", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -8177,10 +9466,7 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8195,7 +9481,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         cb(cur, "result_output", -1);
 
@@ -8205,7 +9491,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8237,7 +9523,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -8253,7 +9539,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8277,17 +9563,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8296,7 +9586,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -8305,7 +9595,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8328,13 +9618,13 @@ struct llm_build_context {
 
             // self-attention
             {
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
@@ -8343,7 +9633,7 @@ struct llm_build_context {
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 cb(Qcur, "Qcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8365,16 +9655,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8389,7 +9680,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -8398,7 +9689,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8412,8 +9703,6 @@ struct llm_build_context {
         if (model.arch != LLM_ARCH_JINA_BERT_V2) {
             inp_pos = build_inp_pos();
         }
-        struct ggml_tensor * inp_mean = build_inp_mean();
-        struct ggml_tensor * inp_cls  = build_inp_cls();
 
         // construct input embeddings (token, type, position)
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
@@ -8443,7 +9732,7 @@ struct llm_build_context {
 
             // self-attention
             if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
-                Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), model.layers[il].bq);
+                Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur), model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
                 if (model.layers[il].attn_q_norm) {
@@ -8453,7 +9742,7 @@ struct llm_build_context {
                             LLM_NORM, cb, il);
                 }
 
-                Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), model.layers[il].bk);
+                Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur), model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
                 if (model.layers[il].attn_k_norm) {
@@ -8462,14 +9751,14 @@ struct llm_build_context {
                             model.layers[il].attn_k_norm_b,
                             LLM_NORM, cb, il);
                 }
-                Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), model.layers[il].bv);
+                Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur), model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             } else {
                 // compute Q and K and RoPE them
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
@@ -8518,7 +9807,7 @@ struct llm_build_context {
 
             ggml_build_forward_expand(gf, cur);
 
-            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
             if (model.layers[il].bo) {
                 cb(cur, "kqv_wo", il);
             }
@@ -8551,24 +9840,24 @@ struct llm_build_context {
 
             // feed-forward network
             if (model.arch == LLM_ARCH_BERT) {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
             } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL,                        NULL,
+                        model.layers[il].ffn_gate, NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
             } else {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             }
@@ -8588,35 +9877,13 @@ struct llm_build_context {
         cur = inpL;
         cb(cur, "result_embd", -1);
 
-        // pooling layer
-        switch (pooling_type) {
-            case LLAMA_POOLING_TYPE_NONE:
-                {
-                    // nop
-                } break;
-            case LLAMA_POOLING_TYPE_MEAN:
-                {
-                    cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_CLS:
-                {
-                    cur = ggml_get_rows(ctx0, cur, inp_cls);
-                    cb(cur, "result_embd_pooled", -1);
-                } break;
-            case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                {
-                    GGML_ASSERT(false && "Invalid pooling type");
-                } break;
-        }
-
         ggml_build_forward_expand(gf, cur);
 
         return gf;
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8645,7 +9912,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -8661,7 +9928,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8685,17 +9952,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -8704,7 +9975,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -8713,7 +9984,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8751,7 +10022,7 @@ struct llm_build_context {
             {
                 cur = attn_norm;
 
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 if (model.layers[il].bqkv){
@@ -8789,13 +10060,13 @@ struct llm_build_context {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
                     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 } else {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                    cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                    cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
                             Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 }
@@ -8819,16 +10090,17 @@ struct llm_build_context {
                         model.layers[il].ffn_norm_b,
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         model.layers[il].ffn_act,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8843,7 +10115,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -8883,21 +10155,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -8939,7 +10211,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -8967,16 +10239,17 @@ struct llm_build_context {
                     // parallel residual
                     cur = inpSA;
                 }
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -8992,7 +10265,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9001,7 +10274,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9027,7 +10300,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9057,7 +10330,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9079,16 +10352,17 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9103,7 +10377,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9112,7 +10386,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9141,17 +10415,17 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
@@ -9170,7 +10444,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9191,15 +10465,16 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9214,7 +10489,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9223,7 +10498,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9255,17 +10530,17 @@ struct llm_build_context {
             // self_attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                 cb(Vcur, "Vcur", il);
@@ -9284,7 +10559,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9307,7 +10582,7 @@ struct llm_build_context {
             cb(cur, "ffn_norm", il);
 
             ggml_tensor * moe_out =
-                    llm_build_moe_ffn(ctx0, cur,
+                    llm_build_moe_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_gate_inp,
                         model.layers[il].ffn_up_exps,
                         model.layers[il].ffn_gate_exps,
@@ -9320,17 +10595,17 @@ struct llm_build_context {
 
             // FFN shared expert
             {
-                ggml_tensor * cur_gate_inp = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
+                ggml_tensor * cur_gate_inp = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_gate_inp_shexp, cur);
                 cb(cur_gate_inp, "ffn_shexp_gate_inp", il);
 
                 // sigmoid
                 ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp);
                 cb(cur_gate, "ffn_shexp_gate", il);
 
-                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up_shexp,   NULL,
-                        model.layers[il].ffn_gate_shexp, NULL,
-                        model.layers[il].ffn_down_shexp, NULL,
+                ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up_shexp,   NULL, NULL,
+                        model.layers[il].ffn_gate_shexp, NULL, NULL,
+                        model.layers[il].ffn_down_shexp, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur_ffn, "ffn_shexp", il);
@@ -9345,6 +10620,7 @@ struct llm_build_context {
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9359,7 +10635,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9368,7 +10644,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9401,7 +10677,7 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = nullptr;
 
                 if (model.layers[il].wqkv) {
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
                     cb(cur, "wqkv", il);
 
                     cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9411,9 +10687,9 @@ struct llm_build_context {
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
                 } else {
-                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
                 }
 
                 cb(Qcur, "Qcur", il);
@@ -9440,7 +10716,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -9455,21 +10731,21 @@ struct llm_build_context {
 
             // FF
             {
-                ffn_output = llm_build_ffn(ctx0, attn_norm_output,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(ffn_output, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_output);
-            cb(cur, "l_out", il);
-
             cur = ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
+            // input for next layer
             inpL = cur;
         }
 
@@ -9479,7 +10755,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output_no_bias", -1);
 
         cur = ggml_add(ctx0, cur, model.output_b);
@@ -9489,7 +10765,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -9504,7 +10780,7 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
 
         for (int il = 0; il < n_layer; ++il) {
             auto residual = inpL;
@@ -9525,7 +10801,7 @@ struct llm_build_context {
                 struct ggml_tensor * Vcur = nullptr;
 
                 if (model.layers[il].wqkv) {
-                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, attn_norm_output);
                     cb(cur, "wqkv", il);
 
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
@@ -9533,9 +10809,9 @@ struct llm_build_context {
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
                 }
                 else {
-                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
-                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
-                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                    Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
                 }
 
                 cb(Qcur, "Qcur", il);
@@ -9560,9 +10836,9 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9584,25 +10860,20 @@ struct llm_build_context {
             // special-case: the up and gate tensors are merged into a single tensor
             // TOOD: support into llm_build_ffn
             {
-                struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
-                cb(up, "ffn_up", il);
-
-                auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0));
-                auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2));
-
-                y = ggml_mul(ctx0, y, ggml_silu(ctx0, g));
-                cb(y, "ffn_gate", il);
-
-                auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y);
-                cb(down, "ffn_down", il);
-
-                cur = down;
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
+            // input for next layer
             inpL = cur;
         }
 
@@ -9612,7 +10883,7 @@ struct llm_build_context {
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9652,13 +10923,13 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -9673,7 +10944,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9691,19 +10962,18 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up, NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, sa_out);
-            cb(cur, "l_out", il);
-
             cur = ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -9718,7 +10988,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9727,7 +10997,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9760,7 +11030,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9776,7 +11046,7 @@ struct llm_build_context {
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9800,17 +11070,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -9819,7 +11093,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9828,7 +11102,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9855,7 +11129,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -9883,7 +11157,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -9907,17 +11181,21 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
-            inpL = ggml_add(ctx0, cur, ffn_inp);
-            cb(inpL, "l_out", il);
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
         }
 
         cur = llm_build_norm(ctx0, inpL, hparams,
@@ -9926,7 +11204,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -9935,7 +11213,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9964,21 +11242,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 // if (model.layers[il].bq) {
                 //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                 //     cb(Qcur, "Qcur", il);
                 // }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 // if (model.layers[il].bk) {
                 //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                 //     cb(Kcur, "Kcur", il);
                 // }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 // if (model.layers[il].bv) {
                 //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -9999,7 +11277,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10020,15 +11298,16 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10043,7 +11322,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10052,7 +11331,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10081,21 +11360,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10116,7 +11395,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10137,15 +11416,16 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10160,7 +11440,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10172,7 +11452,7 @@ struct llm_build_context {
     //      https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
     // based on the original build_llama() function
     struct ggml_cgraph * build_minicpm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10211,21 +11491,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10246,7 +11526,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10273,10 +11553,10 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
@@ -10287,6 +11567,7 @@ struct llm_build_context {
             cb(cur, "hidden_scaled_ffn", -1);
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10306,7 +11587,7 @@ struct llm_build_context {
         cb(cur, "lmhead_scaling", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10315,7 +11596,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -10343,18 +11624,18 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
@@ -10363,11 +11644,11 @@ struct llm_build_context {
 
                 Kcur = ggml_rope_ext(
                         ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
-                        n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
@@ -10389,16 +11670,17 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up, NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             }
 
             cur = ggml_add(ctx0, cur, sa_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10413,7 +11695,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10421,8 +11703,143 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_gemma2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+        cb(inpL, "inp_scaled", -1);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask     = build_inp_KQ_mask(true);
+        struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
+
+        for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask_l = (il % 2 == 0) ? KQ_mask_swa : KQ_mask;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head,    n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
+                switch (model.type) {
+                    case e_model::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case e_model::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    default: GGML_ABORT("fatal error");
+                };
+                cb(Qcur, "Qcur_scaled", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_post_norm", il);
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            cb(sa_out, "sa_out", il);
+
+            cur = llm_build_norm(ctx0, sa_out, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                model.layers[il].ffn_post_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+            cb(cur, "ffn_post_norm", -1);
+
+            cur = ggml_add(ctx0, cur, sa_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10451,21 +11868,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10486,7 +11903,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10508,14 +11925,16 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
             cb(cur, "ffn_out", il);
+
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10530,7 +11949,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10539,7 +11958,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t d_model = n_embd;
         const int64_t d_conv  = hparams.ssm_d_conv;
@@ -10582,7 +12001,7 @@ struct llm_build_context {
             cb(cur, "attn_norm", il);
 
             // {n_embd, 2*d_inner} * {n_embd, n_tokens} => {2*d_inner, n_tokens}
-            struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
+            struct ggml_tensor * xz = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
             // => {d_inner, n_tokens}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
@@ -10621,14 +12040,14 @@ struct llm_build_context {
             // ssm
             {
                 // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tokens} => {dt_rank + 2*d_state, n_tokens}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_x, x);
                 // split
                 struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, n_tokens, x_db->nb[1], 0);
                 struct ggml_tensor * B  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
                 struct ggml_tensor * C  = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
 
                 // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
-                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
                 dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
 
                 // Custom operator to optimize the parallel associative scan
@@ -10659,11 +12078,12 @@ struct llm_build_context {
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
                 // {d_inner, n_embd} * {d_inner, n_tokens} => {n_embd, n_tokens}
-                cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_out, y);
             }
 
             // residual
             cur = ggml_add(ctx0, cur, inpL);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10677,7 +12097,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10687,7 +12107,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10716,21 +12136,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -10776,7 +12196,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10793,10 +12213,10 @@ struct llm_build_context {
 
             // feed-forward network
             {
-                cur = llm_build_ffn(ctx0, ffn_inp,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+                cur = llm_build_ffn(ctx0, lctx, ffn_inp,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
@@ -10805,6 +12225,7 @@ struct llm_build_context {
             // add together residual + FFN + self-attention
             cur = ggml_add(ctx0, cur, inpL);
             cur = ggml_add(ctx0, cur, attn_out);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10819,7 +12240,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         if (f_logit_scale) {
             cur = ggml_scale(ctx0, cur, f_logit_scale);
@@ -10840,7 +12261,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10872,21 +12293,21 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
                     cb(Qcur, "Qcur", il);
                 }
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
                     cb(Kcur, "Kcur", il);
                 }
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
                 if (hparams.f_clamp_kqv > 0.0f) {
                     Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
@@ -10907,7 +12328,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, nullptr,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -10929,10 +12350,10 @@ struct llm_build_context {
                     LLM_NORM, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
@@ -10940,10 +12361,7 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
 
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -10958,7 +12376,132 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_openelm() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head    = hparams.n_head(il);
+            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_qkv = 2*n_head_kv + n_head;
+
+            cur = inpL;
+            struct ggml_tensor * residual = cur;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_reshape_3d(ctx0, cur, n_embd_head_k, n_head_qkv, n_tokens);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, cur->nb[1], cur->nb[2], 0));
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head));
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, hparams,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, hparams,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, NULL, n_rot, rope_type, n_ctx_orig,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_kv, n_tokens);
+                cb(Qcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10967,7 +12510,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -10993,7 +12536,7 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
                 cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
@@ -11021,7 +12564,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11046,10 +12589,10 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
@@ -11057,8 +12600,12 @@ struct llm_build_context {
                 cur = ggml_add(ctx0, cur, inpL);
                 cb(cur, "ffn_out", il);
 
-                inpL = ggml_add(ctx0, cur, attn_out);
-                cb(inpL, "l_out", il);
+                cur = ggml_add(ctx0, cur, attn_out);
+                cur = lctx.cvec.apply_to(ctx0, cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
             } else {
                 // attention and ffn are computed sequentially
                 // x = x + attn(ln1(x))
@@ -11073,16 +12620,20 @@ struct llm_build_context {
                         LLM_NORM, cb, il);
                 cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
-                        NULL,                      NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        NULL,                      NULL,                        NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
 
-                inpL = ggml_add(ctx0, cur, ffn_inp);
-                cb(inpL, "l_out", il);
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cur = lctx.cvec.apply_to(ctx0, cur, il);
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
             }
         }
 
@@ -11092,7 +12643,7 @@ struct llm_build_context {
                 LLM_NORM, cb, -1);
         cb(cur, "result_norm", -1);
 
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11101,7 +12652,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -11133,13 +12684,13 @@ struct llm_build_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_rope_ext(
@@ -11156,7 +12707,7 @@ struct llm_build_context {
                 );
                 cb(Kcur, "Kcur", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
@@ -11178,10 +12729,10 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            cur = llm_build_ffn(ctx0, cur,
-                    model.layers[il].ffn_up,   NULL,
-                    model.layers[il].ffn_gate, NULL,
-                    model.layers[il].ffn_down, NULL,
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
                     NULL,
                     LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
             cb(cur, "ffn_out", il);
@@ -11195,7 +12746,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm_exps", il);
 
-            cur = llm_build_moe_ffn(ctx0, cur,
+            cur = llm_build_moe_ffn(ctx0, lctx, cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -11209,10 +12760,7 @@ struct llm_build_context {
             cur = ggml_add(ctx0, cur, ffn_out);
             cb(cur, "ffn_out", il);
 
-            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
-            if (layer_dir != nullptr) {
-                cur = ggml_add(ctx0, cur, layer_dir);
-            }
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -11227,7 +12775,7 @@ struct llm_build_context {
         cb(cur, "result_norm", -1);
 
         // lm_head
-        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -11236,7 +12784,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -11381,7 +12929,7 @@ struct llm_build_context {
                 struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(k_states, "k_states", il);
 
-                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
                         model.layers[il].wo, NULL,
                         k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
             }
@@ -11397,28 +12945,23 @@ struct llm_build_context {
             struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
-            if ((uint32_t) il < hparams.n_layer_dense_lead) {
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
 
-                cur = llm_build_ffn(ctx0, cur,
-                        model.layers[il].ffn_up,   NULL,
-                        model.layers[il].ffn_gate, NULL,
-                        model.layers[il].ffn_down, NULL,
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                 cb(cur, "ffn_out", il);
             } else {
                 // MoE branch
-                cur = llm_build_norm(ctx0, ffn_inp, hparams,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, cb, il);
-                cb(cur, "ffn_norm", il);
-
                 ggml_tensor * moe_out =
-                        llm_build_moe_ffn(ctx0, cur,
+                        llm_build_moe_ffn(ctx0, lctx, cur,
                             model.layers[il].ffn_gate_inp,
                             model.layers[il].ffn_up_exps,
                             model.layers[il].ffn_gate_exps,
@@ -11431,10 +12974,10 @@ struct llm_build_context {
 
                 // FFN shared expert
                 {
-                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur,
-                            model.layers[il].ffn_up_shexp,   NULL,
-                            model.layers[il].ffn_gate_shexp, NULL,
-                            model.layers[il].ffn_down_shexp, NULL,
+                    ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
                             NULL,
                             LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
                     cb(ffn_shexp, "ffn_shexp", il);
@@ -11445,6 +12988,7 @@ struct llm_build_context {
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
             cb(cur, "l_out", il);
 
             // input for next layer
@@ -11467,6 +13011,668 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_bitnet() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                // B1.K
+                struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                // B1.V
+                struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        NULL, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_sub_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_sub_norm", il);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
+                if (model.layers[il].bo) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bo);
+                }
+                cb(cur, "attn_o_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward forward
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, model.layers[il].ffn_up_scale,
+                    model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale,
+                    NULL,                      NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_sub_out", il);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                            model.layers[il].ffn_sub_norm, NULL,
+                            LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_sub_norm", il);
+
+            cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
+            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
+            cb(cur, "ffn_down", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = llm_build_lora_mm(lctx, ctx0, model.tok_embd, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+        return gf;
+    }
+
+    struct ggml_cgraph * build_t5() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        if (lctx.is_encoding) {
+            struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
+
+            // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+            struct ggml_tensor * KQ_mask_enc = build_inp_KQ_mask(false);
+
+            for (int il = 0; il < n_layer; ++il) {
+                struct ggml_tensor * inpSA = inpL;
+
+                // norm
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].attn_norm_enc, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+
+                // self-attention
+                {
+                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_enc, cur);
+                    cb(Qcur, "Qcur", il);
+
+                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_enc, cur);
+                    cb(Kcur, "Kcur", il);
+
+                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_enc, cur);
+                    cb(Vcur, "Vcur", il);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+                    struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc;
+                    struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_enc, attn_rel_b);
+                    struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+                    cb(kq_b, "kq_b", il);
+
+                    kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_enc, 1.0f, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
+                    cb(v, "v", il);
+
+                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_tokens, n_embd_head, n_head_kv), kq);
+                    cb(kqv, "kqv", il);
+
+                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                    cb(cur, "kqv_merged_cont", il);
+
+                    ggml_build_forward_expand(gf, cur);
+
+                    cur = ggml_mul_mat(ctx0, model.layers[il].wo_enc, cur);
+                    cb(cur, "kqv_out", il);
+                }
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                }
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                {
+                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                            model.layers[il].ffn_norm_enc, NULL,
+                            LLM_NORM_RMS, cb, il);
+                    cb(cur, "ffn_norm", il);
+
+                    // T5 uses relu, flan-T5 uses gelu-gated
+                    cur = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up_enc,   NULL, NULL,
+                            model.layers[il].ffn_gate_enc, NULL, NULL,
+                            model.layers[il].ffn_down_enc, NULL, NULL,
+                            NULL,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR  : LLM_FFN_SEQ,
+                            cb, il);
+                    cb(cur, "ffn_out", il);
+                }
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+                if (layer_dir != nullptr) {
+                    cur = ggml_add(ctx0, cur, layer_dir);
+                }
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+
+            cur = inpL;
+            cb(cur, "result_embd", -1);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.output_norm_enc, NULL,
+                    LLM_NORM_RMS, cb, -1);
+            cb(cur, "result_norm", -1);
+        } else {
+            GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
+
+            struct ggml_tensor * embd_enc       = llm_build_inp_embd_enc();
+            struct ggml_tensor * pos_bucket_dec = llm_build_pos_bucket(true);
+
+            struct ggml_tensor * KQ_mask_dec   = build_inp_KQ_mask();
+            struct ggml_tensor * KQ_mask_cross = llm_build_inp_KQ_mask_cross();
+
+            for (int il = 0; il < n_layer; ++il) {
+                struct ggml_tensor * inpSA = inpL;
+
+                // norm
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm", il);
+
+                // self-attention
+                {
+                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                    cb(Qcur, "Qcur", il);
+
+                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                    cb(Kcur, "Kcur", il);
+
+                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                    cb(Vcur, "Vcur", il);
+
+                    llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+
+                    struct ggml_tensor * k =
+                        ggml_view_3d(ctx0, kv_self.k_l[il],
+                                n_embd_head_k, n_kv, n_head_kv,
+                                ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                                ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                                0);
+                    cb(k, "k", il);
+
+                    struct ggml_tensor * v =
+                        ggml_view_3d(ctx0, kv_self.v_l[il],
+                                n_kv, n_embd_head_v, n_head_kv,
+                                ggml_element_size(kv_self.v_l[il])*n_ctx,
+                                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                                0);
+                    cb(v, "v", il);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                    struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+
+                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+                    struct ggml_tensor * pos_bias = llm_build_pos_bias(pos_bucket_dec, attn_rel_b);
+                    struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias);
+                    cb(kq_b, "kq_b", il);
+
+                    kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+                    cb(kqv, "kqv", il);
+
+                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                    cb(cur, "kqv_merged_cont", il);
+
+                    ggml_build_forward_expand(gf, cur);
+
+                    cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
+                    cb(cur, "kqv_out", il);
+                }
+
+                cur = ggml_add(ctx0, cur, inpSA);
+                cb(cur, "cross_inp", il);
+
+                struct ggml_tensor * inpCA = cur;
+
+                // norm
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].attn_norm_cross, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "attn_norm_cross", il);
+
+                // cross-attention
+                {
+                    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq_cross, cur);
+                    cb(Qcur, "Qcur", il);
+
+                    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk_cross, embd_enc);
+                    cb(Kcur, "Kcur", il);
+
+                    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv_cross, embd_enc);
+                    cb(Vcur, "Vcur", il);
+
+                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
+
+                    struct ggml_tensor * q =                 ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+                    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
+
+                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc)));
+                    cb(v, "v", il);
+
+                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq);
+                    cb(kqv, "kqv", il);
+
+                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens);
+                    cb(cur, "kqv_merged_cont", il);
+
+                    ggml_build_forward_expand(gf, cur);
+
+                    cur = ggml_mul_mat(ctx0, model.layers[il].wo_cross, cur);
+                    cb(cur, "kqv_out", il);
+                }
+
+                if (il == n_layer - 1) {
+                    // skip computing output for unused tokens
+                    struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                    n_tokens = n_outputs;
+                    cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                    inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
+                }
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA);
+                cb(ffn_inp, "ffn_inp", il);
+
+                // feed-forward network
+                {
+                    cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                            model.layers[il].ffn_norm, NULL,
+                            LLM_NORM_RMS, cb, il);
+                    cb(cur, "ffn_norm", il);
+
+                    // T5 uses relu, flan-T5 uses gelu-gated
+                    cur = llm_build_ffn(ctx0, lctx, cur,
+                            model.layers[il].ffn_up,   NULL, NULL,
+                            model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL,
+                            NULL,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
+                            model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+                            cb, il);
+                    cb(cur, "ffn_out", il);
+                }
+
+                cur = ggml_add(ctx0, cur, ffn_inp);
+                cb(cur, "ffn_out", il);
+
+                ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+                if (layer_dir != nullptr) {
+                    cur = ggml_add(ctx0, cur, layer_dir);
+                }
+                cb(cur, "l_out", il);
+
+                // input for next layer
+                inpL = cur;
+            }
+
+            cur = inpL;
+            cb(cur, "result_embd", -1);
+
+            cur = llm_build_norm(ctx0, cur, hparams,
+                    model.output_norm, NULL,
+                    LLM_NORM_RMS, cb, -1);
+            cb(cur, "result_norm", -1);
+
+            // lm_head
+            cur = ggml_mul_mat(ctx0, model.output, cur);
+            cb(cur, "result_output", -1);
+        }
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_jais() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    struct ggml_cgraph * build_chatglm() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                struct ggml_tensor * Qcur = nullptr;
+                struct ggml_tensor * Kcur = nullptr;
+                struct ggml_tensor * Vcur = nullptr;
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+                //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur_rope", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur_rope", il);
+
+                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            // Add the input
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // FF
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        NULL,                      NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+            }
+
+            inpL = ggml_add(ctx0, cur, ffn_inp);
+            cb(inpL, "l_out", il);
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) {
@@ -11547,7 +13753,8 @@ static struct ggml_cgraph * llama_build_graph(
         if (batch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
                 for (auto * backend : lctx.backends) {
-                    if (ggml_backend_buft_supports_backend(lctx.model.buft_layer[il].buft, backend)) {
+                    if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft) &&
+                        (ggml_backend_supports_op(backend, cur) || ggml_backend_offload_op(backend, cur))) {
                         ggml_backend_sched_set_tensor_backend(lctx.sched, cur, backend);
                         break;
                     }
@@ -11653,6 +13860,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_gemma();
             } break;
+        case LLM_ARCH_GEMMA2:
+            {
+                result = llm.build_gemma2();
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 result = llm.build_starcoder2();
@@ -11677,6 +13888,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OPENELM:
+            {
+                result = llm.build_openelm();
+            } break;
         case LLM_ARCH_GPTNEOX:
             {
                 result = llm.build_gptneox();
@@ -11689,8 +13904,29 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_deepseek2();
             } break;
+        case LLM_ARCH_CHATGLM:
+            {
+                result = llm.build_chatglm();
+            } break;
+        case LLM_ARCH_BITNET:
+            {
+                result = llm.build_bitnet();
+            } break;
+        case LLM_ARCH_T5:
+            {
+                result = llm.build_t5();
+            } break;
+        case LLM_ARCH_JAIS:
+            {
+                result = llm.build_jais();
+            } break;
         default:
-            GGML_ASSERT(false);
+            GGML_ABORT("fatal error");
+    }
+
+    // add on pooling layer
+    if (lctx.cparams.embeddings) {
+        result = llm.append_pooling(result);
     }
 
     llm.free();
@@ -11722,6 +13958,30 @@ static void llama_set_s_copy(llama_context & lctx) {
     }
 }
 
+static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min(relative_position, 0);
+    }
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+    return relative_bucket;
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -11782,18 +14042,28 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         // (!a || b) is a logical implication (a -> b)
         // !hparams.causal_attn -> !cparams.causal_attn
         (hparams.causal_attn || !cparams.causal_attn) &&
-        "causal attention with embedding models is not supported"
+        "causal attention is not supported by this model"
     );
 
-    if (lctx.inp_KQ_mask) {
+    if (lctx.inp_KQ_mask || lctx.inp_KQ_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-        if (cparams.causal_attn) {
+        if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv     = kv_self.n;
             const int64_t n_tokens = batch.n_tokens;
 
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data     = nullptr;
+            float * data_swa = nullptr;
+
+            if (lctx.inp_KQ_mask) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+                data = (float *) lctx.inp_KQ_mask->data;
+            }
+
+            if (lctx.inp_KQ_mask_swa) {
+                GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_swa->buffer));
+                data_swa = (float *) lctx.inp_KQ_mask_swa->data;
+            }
 
             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
@@ -11809,25 +14079,46 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                             f = -INFINITY;
                         } else {
                             if (hparams.use_alibi) {
-                                f = -fabs(lctx.kv_self.cells[i].pos - pos);
+                                f = -std::abs(lctx.kv_self.cells[i].pos - pos);
                             } else {
                                 f = 0.0f;
                             }
                         }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
+
+                        // may need to cut off old tokens for sliding window
+                        if (data_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= (int32_t)hparams.n_swa) {
+                                f = -INFINITY;
+                            }
+                            data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                        }
                     }
                 }
 
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                if (data) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
+                    }
+                }
+
+                if (data_swa) {
+                    for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                        for (int j = 0; j < n_kv; ++j) {
+                            data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                        }
                     }
                 }
             }
         } else {
             // when using kv cache, the mask needs to match the kv cache size
             const int64_t n_tokens = batch.n_tokens;
-            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+            const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
 
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
@@ -11842,7 +14133,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
                             if (batch.seq_id[i][s] == seq_id) {
                                 if (hparams.use_alibi) {
-                                    f = -fabs(batch.pos[i] - batch.pos[j]);
+                                    f = -std::abs(batch.pos[i] - batch.pos[j]);
                                 } else {
                                     f = 0.0f;
                                 }
@@ -11861,7 +14152,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_mean);
@@ -11893,7 +14184,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -11914,6 +14205,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(lctx.inp_cls);
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
+
+        uint32_t * data = (uint32_t *) lctx.inp_cls->data;
+        memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
+
+        std::vector last_pos(n_tokens, -1);
+        std::vector last_row(n_tokens, -1);
+
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = batch.seq_id[i][0];
+            const llama_pos    pos    = batch.pos[i];
+
+            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
+
+            if (pos >= last_pos[seq_id]) {
+                last_pos[seq_id] = pos;
+                last_row[seq_id] = i;
+            }
+        }
+
+        for (int i = 0; i < n_tokens; ++i) {
+            if (last_row[i] >= 0) {
+                data[i] = last_row[i];
+            }
+        }
+    }
+
     if (kv_self.recurrent) {
         const int64_t n_kv = kv_self.n;
 
@@ -11960,6 +14282,70 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
     }
+
+    if (lctx.inp_pos_bucket) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
+
+        int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
+
+        if (!lctx.is_encoding) {
+            const int64_t n_kv = kv_self.n;
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    for (int i = 0; i < n_kv; ++i) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                    }
+                }
+            }
+        } else {
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    for (int i = 0; i < n_tokens; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                    }
+                }
+            }
+        }
+    }
+
+    if (!lctx.is_encoding && lctx.inp_embd_enc) {
+        assert(lctx.inp_embd_enc->type == GGML_TYPE_F32);
+        assert((size_t) ggml_nelements(lctx.inp_embd_enc) == lctx.embd_enc.size());
+
+        ggml_backend_tensor_set(lctx.inp_embd_enc, lctx.embd_enc.data(), 0, ggml_nbytes(lctx.inp_embd_enc));
+    }
+
+    if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
+        const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
+        const int64_t n_tokens = batch.n_tokens;
+
+        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
+
+        float * data = (float *) lctx.inp_KQ_mask_cross->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                for (int i = 0; i < n_output_enc; ++i) {
+                    float f = -INFINITY;
+                    for (int s = 0; s < batch.n_seq_id[j]; ++s) {
+                        const llama_seq_id seq_id = batch.seq_id[j][s];
+                        if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
+                            f = 0.0f;
+                        }
+                    }
+                    data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f;
+                }
+            }
+
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_output_enc; ++j) {
+                    data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY;
+                }
+            }
+        }
+    }
 }
 
 // Make sure enough space is available for outputs.
@@ -11975,8 +14361,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
     const auto n_embd  = hparams.n_embd;
 
     // TODO: use a per-batch flag for logits presence instead
-    const bool has_logits = cparams.causal_attn;
-    const bool has_embd   = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+    const bool has_logits =  cparams.causal_attn;
+    const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
 
     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
     const size_t embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
@@ -12044,6 +14430,11 @@ static void llama_graph_compute(
         ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
+#ifdef GGML_USE_BLAS
+    if (lctx.backend_blas != nullptr) {
+        ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
+    }
+#endif
 
     ggml_backend_sched_graph_compute_async(lctx.sched, gf);
 
@@ -12063,6 +14454,7 @@ static int llama_decode_internal(
          llama_context & lctx,
            llama_batch   batch_all) { // TODO: rename back to batch
 
+    lctx.is_encoding = false;
     const uint32_t n_tokens_all = batch_all.n_tokens;
 
     if (n_tokens_all == 0) {
@@ -12095,17 +14487,21 @@ static int llama_decode_internal(
 
     const auto n_ubatch = cparams.n_ubatch;
 
+    // TODO: simplify or deprecate
     std::vector pos;
     std::vector                   n_seq_id;
     std::vector            seq_id_arr;
     std::vector> seq_id;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     // count outputs
-    if (batch_all.logits) {
+    if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+    } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
@@ -12151,7 +14547,7 @@ static int llama_decode_internal(
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits) {
+            if (u_batch.logits && !embd_pooled) {
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
@@ -12236,47 +14632,27 @@ static int llama_decode_internal(
             // no output
             res  = nullptr;
             embd = nullptr;
-        } else if (!hparams.causal_attn) {
-            res = nullptr; // do not extract logits for embedding models such as BERT
+        }
 
-            // token or sequence embeddings
-            embd = gf->nodes[gf->n_nodes - 1];
-
-            GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
-        } else if (cparams.embeddings) {
-            // the embeddings could be in the second to last tensor, or any of the previous tensors
-            int i_embd = gf->n_nodes - 2;
-            for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
-                i_embd = gf->n_nodes - i;
-                if (i_embd < 0) { break; }
-                embd = gf->nodes[i_embd];
+        if (cparams.embeddings) {
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                embd = gf->nodes[i];
+                if (strcmp(embd->name, "result_embd_pooled") == 0) {
+                    break;
+                }
             }
-            GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
-
-            // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
-            if (!cparams.causal_attn) {
-                res = nullptr; // do not extract logits when not needed
-                // skip computing logits
-                // TODO: is this safe?
-                gf->n_nodes = i_embd + 1;
-            }
-        } else {
+            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
         }
-        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        // for big prompts, if BLAS is enabled, it is better to use only one thread
-        // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
-        // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
-        //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
-        //       with the BLAS calls. need a better solution
-        // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is
-        //                   being processed then Accelerate/BLAS will not be involved, so capping would limit performance.
-        if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
-            n_threads = std::min(4, n_threads);
+        if (!cparams.causal_attn) {
+            res = nullptr; // do not extract logits when not needed
         }
 
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+
         ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
         llama_set_inputs(lctx, u_batch);
@@ -12293,12 +14669,6 @@ static int llama_decode_internal(
             }
         }
 
-#ifdef GGML_PERF
-        // print timing information per ggml operation (for debugging purposes)
-        // requires GGML_PERF to be defined
-        ggml_graph_print(gf);
-#endif
-
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
@@ -12339,11 +14709,10 @@ static int llama_decode_internal(
                             ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
                         }
                     } break;
-                case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_MEAN:
+                case LLAMA_POOLING_TYPE_CLS:
+                case LLAMA_POOLING_TYPE_LAST:
                     {
-                        GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
-
                         // extract sequence embeddings
                         auto & embd_seq_out = lctx.embd_seq;
                         embd_seq_out.clear();
@@ -12359,8 +14728,8 @@ static int llama_decode_internal(
                     } break;
                 case LLAMA_POOLING_TYPE_UNSPECIFIED:
                     {
-                        GGML_ASSERT(false && "unknown pooling type");
-                    } break;
+                        GGML_ABORT("unknown pooling type");
+                    }
             }
         }
         n_outputs_prev += lctx.n_outputs;
@@ -12391,6 +14760,138 @@ static int llama_decode_internal(
     return 0;
 }
 
+// encode a batch of tokens by evaluating the encoder part of the transformer
+//
+//   - lctx:      llama context
+//   - batch:     batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_encode_internal(
+         llama_context & lctx,
+           llama_batch   batch) {
+
+    lctx.is_encoding = true;
+
+    const uint32_t n_tokens = batch.n_tokens;
+
+    if (n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
+        return -1;
+    }
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
+    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+
+    lctx.n_queued_tokens += n_tokens;
+
+    const int64_t n_embd = hparams.n_embd;
+
+    // TODO: simplify or deprecate
+    std::vector pos;
+    std::vector                   n_seq_id;
+    std::vector            seq_id_arr;
+    std::vector> seq_id;
+
+    // reserve output buffer
+    if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
+        LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
+        return -2;
+    };
+
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        lctx.output_ids[i] = i;
+    }
+
+    lctx.inp_embd_enc = NULL;
+    lctx.n_outputs = n_tokens;
+
+    const int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+    GGML_ASSERT(n_threads > 0);
+
+    // helpers for smoother batch API transition
+    // after deprecating the llama_eval calls, these will be removed
+    if (batch.pos == nullptr) {
+        pos.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
+        }
+
+        batch.pos = pos.data();
+    }
+
+    if (batch.seq_id == nullptr) {
+        n_seq_id.resize(n_tokens);
+        seq_id.resize(n_tokens);
+        seq_id_arr.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            n_seq_id[i] = 1;
+            seq_id[i].resize(1);
+            seq_id[i][0] = batch.all_seq_id;
+            seq_id_arr[i] = seq_id[i].data();
+        }
+
+        batch.n_seq_id = n_seq_id.data();
+        batch.seq_id = seq_id_arr.data();
+    }
+
+    ggml_backend_sched_reset(lctx.sched);
+    ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+
+    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);
+
+    // the output embeddings after the final encoder normalization
+    struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 1];
+
+    GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
+
+    ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+    llama_set_inputs(lctx, batch);
+
+    llama_graph_compute(lctx, gf, n_threads);
+
+    // extract embeddings
+    if (embd) {
+        ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
+        GGML_ASSERT(backend_embd != nullptr);
+
+        // extract token embeddings
+        GGML_ASSERT(lctx.embd != nullptr);
+
+        lctx.embd_enc.resize(n_tokens*n_embd);
+        float * embd_out = lctx.embd_enc.data();
+
+        ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_tokens*n_embd*sizeof(float));
+
+        // remember the sequence ids used during the encoding - needed for cross attention later
+        lctx.seq_ids_enc.resize(n_tokens);
+        for (uint32_t i = 0; i < n_tokens; i++) {
+            for (int s = 0; s < batch.n_seq_id[i]; s++) {
+                llama_seq_id seq_id = batch.seq_id[i][s];
+                lctx.seq_ids_enc[i].insert(seq_id);
+            }
+        }
+    }
+
+    // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
+    // overlap with device computation.
+    ggml_backend_sched_reset(lctx.sched);
+
+    return 0;
+}
 
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
@@ -12413,9 +14914,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -12618,6 +15119,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // apply K-shift if needed
     if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+        if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
+            GGML_ABORT("Deepseek2 does not support K-shift");
+        }
+
         {
             ggml_backend_sched_reset(lctx.sched);
 
@@ -12695,2049 +15200,6 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }
 
-//
-// tokenizer
-//
-
-static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
-    return vocab.type;
-}
-
-static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
-}
-
-static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
-}
-
-static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
-}
-
-static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
-}
-
-static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
-}
-
-static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    GGML_ASSERT(llama_is_byte_token(vocab, id));
-    const auto & token_data = vocab.id_to_token.at(id);
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM: {
-            auto buf = token_data.text.substr(3, 2);
-            return strtol(buf.c_str(), NULL, 16);
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            GGML_ASSERT(false);
-            return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
-        }
-        case LLAMA_VOCAB_TYPE_WPM: {
-            GGML_ASSERT(false);
-        }
-        default:
-            GGML_ASSERT(false);
-    }
-}
-
-static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ASSERT(false);
-    }
-}
-
-static void llama_escape_whitespace(std::string & text) {
-    replace_all(text, " ", "\xe2\x96\x81");
-}
-
-static void llama_unescape_whitespace(std::string & word) {
-    replace_all(word, "\xe2\x96\x81", " ");
-}
-
-struct llm_symbol {
-    using index = int;
-    index prev;
-    index next;
-    const char * text;
-    size_t n;
-};
-
-static_assert(std::is_trivially_copyable::value, "llm_symbol is not trivially copyable");
-
-// SPM tokenizer
-// original implementation:
-// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
-
-struct llm_bigram_spm {
-    struct comparator {
-        bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
-            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
-        }
-    };
-    using queue_storage = std::vector;
-    using queue = std::priority_queue;
-    llm_symbol::index left;
-    llm_symbol::index right;
-    float score;
-    size_t size;
-};
-
-struct llm_tokenizer_spm {
-    llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector & output) {
-        // split string into utf8 chars
-        int index = 0;
-        size_t offs = 0;
-        while (offs < text.size()) {
-            llm_symbol sym;
-            size_t len = utf8_len(text[offs]);
-            sym.text = text.c_str() + offs;
-            sym.n = std::min(len, text.size() - offs);
-            offs += sym.n;
-            sym.prev = index - 1;
-            sym.next = offs == text.size() ? -1 : index + 1;
-            index++;
-            symbols.emplace_back(sym);
-        }
-
-        // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols.size(); ++i) {
-            try_add_bigram(i - 1, i);
-        }
-
-        // keep substituting the highest frequency pairs for as long as we can.
-        while (!work_queue.empty()) {
-            auto bigram = work_queue.top();
-            work_queue.pop();
-
-            auto & left_sym = symbols[bigram.left];
-            auto & right_sym = symbols[bigram.right];
-
-            // if one of the symbols already got merged, skip it.
-            if (left_sym.n == 0 || right_sym.n == 0 ||
-                left_sym.n + right_sym.n != bigram.size) {
-                continue;
-            }
-
-            // merge the right sym into the left one
-            left_sym.n += right_sym.n;
-            right_sym.n = 0;
-
-            //LLAMA_LOG_INFO("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);
-
-            // remove the right sym from the chain
-            left_sym.next = right_sym.next;
-            if (right_sym.next >= 0) {
-                symbols[right_sym.next].prev = bigram.left;
-            }
-
-            // find more substitutions
-            try_add_bigram(left_sym.prev, bigram.left);
-            try_add_bigram(bigram.left, left_sym.next);
-        }
-
-        for (int i = 0; i != -1; i = symbols[i].next) {
-            auto & symbol = symbols[i];
-            resegment(symbol, output);
-        }
-    }
-
-private:
-    void resegment(llm_symbol & symbol, std::vector & output) {
-        auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab.token_to_id.find(text);
-
-        // Do we need to support is_unused?
-        if (token != vocab.token_to_id.end()) {
-            output.push_back((*token).second);
-            return;
-        }
-
-        const auto p = rev_merge.find(text);
-
-        if (p == rev_merge.end()) {
-            // output any symbols that did not form tokens as bytes.
-            output.reserve(output.size() + symbol.n);
-            for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
-                output.push_back(token_id);
-            }
-            return;
-        }
-
-        resegment(symbols[p->second.first],  output);
-        resegment(symbols[p->second.second], output);
-    }
-
-    void try_add_bigram(int left, int right) {
-        if (left == -1 || right == -1) {
-            return;
-        }
-
-        const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
-        auto token = vocab.token_to_id.find(text);
-
-        if (token == vocab.token_to_id.end()) {
-            return;
-        }
-
-        if (static_cast((*token).second) >= vocab.id_to_token.size()) {
-            return;
-        }
-
-        const auto & tok_data = vocab.id_to_token[(*token).second];
-
-        llm_bigram_spm bigram;
-        bigram.left  = left;
-        bigram.right = right;
-        bigram.score = tok_data.score;
-        bigram.size  = text.size();
-
-        work_queue.push(bigram);
-
-        // Do we need to support is_unused?
-        rev_merge[text] = std::make_pair(left, right);
-    }
-
-    const llama_vocab & vocab;
-
-    std::vector symbols;
-    llm_bigram_spm::queue work_queue;
-
-    std::map> rev_merge;
-};
-
-// BPE tokenizer
-// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
-// tried to simplify unicode stuff, so most likely does not work 100% correctly!
-
-// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused
-
-struct llm_bigram_bpe {
-    struct comparator {
-        bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
-            return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
-        }
-    };
-
-    using queue_storage = std::vector;
-    using queue = std::priority_queue;
-    llm_symbol::index left;
-    llm_symbol::index right;
-    std::string text;
-    int rank;
-    size_t size;
-};
-
-struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector & output) {
-        int final_prev_index = -1;
-        bool ignore_merges = false;
-
-        std::vector word_collection;
-        switch (vocab.type) {
-            case LLAMA_VOCAB_TYPE_BPE:
-                switch (vocab.type_pre) {
-                    case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
-                        ignore_merges = true;
-                        word_collection = unicode_regex_split(text, {
-                            // original regex from tokenizer.json
-                            //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-
-                            // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DBRX:
-                    case LLAMA_VOCAB_PRE_TYPE_SMAUG:
-                        word_collection = unicode_regex_split(text, {
-                            // same as llama3
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
-                        word_collection = unicode_regex_split(text, {
-                            "[\r\n]",
-                            "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-                            "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
-                            "\\s+$",
-                            "[一-龥ࠀ-一가-퟿]+",
-                            "\\p{N}+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
-                        word_collection = unicode_regex_split(text, {
-                            "[\r\n]",
-                            "\\s?\\p{L}+",
-                            "\\s?\\p{P}+",
-                            "[一-龥ࠀ-一가-퟿]+",
-                            "\\p{N}",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_FALCON:
-                        word_collection = unicode_regex_split(text, {
-                            "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                            "[0-9][0-9][0-9]",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_MPT:
-                        // TODO: MPT pre-tokenization regexes are unknown
-                        //       the following are close, but not exact. run the following:
-                        //       ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
-                        GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
-                        word_collection = unicode_regex_split(text, {
-                            "\\s?\\p{L}+",
-                            "\\s?\\p{P}+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_STARCODER:
-                    case LLAMA_VOCAB_PRE_TYPE_REFACT:
-                    case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
-                        word_collection = unicode_regex_split(text, {
-                            "\\p{N}",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_GPT2:
-                    case LLAMA_VOCAB_PRE_TYPE_OLMO:
-                        word_collection = unicode_regex_split(text, {
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
-                    case LLAMA_VOCAB_PRE_TYPE_QWEN2:
-                        word_collection = unicode_regex_split(text, {
-                            // original regex from tokenizer.json
-                            // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    default:
-                        // default regex for BPE tokenization pre-processing
-                        word_collection = unicode_regex_split(text, {
-                            "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                            "\\p{N}+",
-                            "[0-9][0-9][0-9]",
-                        });
-                        break;
-                }
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
-
-        symbols_final.clear();
-
-        for (auto & word : word_collection) {
-            work_queue = llm_bigram_bpe::queue();
-            symbols.clear();
-
-            int index = 0;
-            size_t offset = 0;
-
-            if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
-                symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
-                offset = word.size();
-            }
-
-            while (offset < word.size()) {
-                llm_symbol sym;
-                size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset]));
-                sym.text = word.c_str() + offset;
-                sym.n = char_len;
-                offset += sym.n;
-                sym.prev = index - 1;
-                sym.next = offset == word.size() ? -1 : index + 1;
-                index++;
-                symbols.emplace_back(sym);
-            }
-            for (size_t i = 1; i < symbols.size(); ++i) {
-                add_new_bigram(i - 1, i);
-            }
-
-            // build token(s)
-            while (!work_queue.empty()) {
-                auto bigram = work_queue.top();
-                work_queue.pop();
-
-                auto & left_symbol = symbols[bigram.left];
-                auto & right_symbol = symbols[bigram.right];
-
-                if (left_symbol.n == 0 || right_symbol.n == 0) {
-                    continue;
-                }
-                std::string left_token = std::string(left_symbol.text, left_symbol.n);
-                std::string right_token = std::string(right_symbol.text, right_symbol.n);
-                if (left_token + right_token != bigram.text) {
-                    continue;  // Skip this bigram if it's outdated
-                }
-
-                // merge the right sym into the left one
-                left_symbol.n += right_symbol.n;
-                right_symbol.n = 0;
-
-                // remove the right sym from the chain
-                left_symbol.next = right_symbol.next;
-                if (right_symbol.next >= 0) {
-                    symbols[right_symbol.next].prev = bigram.left;
-                }
-
-                add_new_bigram(left_symbol.prev, bigram.left);  // left side of current symbol
-                add_new_bigram(bigram.left, left_symbol.next);  // right side of current symbol
-            }
-
-            // add the finished tokens to the final list keeping correct order for next and prev
-            for (auto & sym : symbols) {
-                if (sym.n > 0) {
-                    sym.prev = final_prev_index;
-                    sym.next = -1;
-                    if (final_prev_index != -1) {
-                        symbols_final[final_prev_index].next = symbols_final.size();
-                    }
-                    symbols_final.emplace_back(sym);
-                    final_prev_index = symbols_final.size() - 1;
-                }
-            }
-        }
-
-        symbols = symbols_final;
-
-        if (!symbols.empty()) {
-            for (int i = 0; i != -1; i = symbols[i].next) {
-                auto & symbol = symbols[i];
-                if (symbol.n == 0) {
-                    continue;
-                }
-
-                const std::string str = std::string(symbol.text, symbol.n);
-                const auto token = vocab.token_to_id.find(str);
-
-                if (token == vocab.token_to_id.end()) {
-                    for (auto j = str.begin(); j != str.end(); ++j) {
-                        std::string byte_str(1, *j);
-                        auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte == vocab.token_to_id.end()) {
-                            throw std::runtime_error("ERROR: byte not found in vocab");
-                        }
-                        output.push_back((*token_multibyte).second);
-                    }
-                } else {
-                    output.push_back((*token).second);
-                }
-            }
-        }
-    }
-
-private:
-    void add_new_bigram(int left, int right) {
-        if (left == -1 || right == -1) {
-            return;
-        }
-
-        std::string left_token  = std::string(symbols[left].text,  symbols[left].n);
-        std::string right_token = std::string(symbols[right].text, symbols[right].n);
-
-        int rank_found = -1;
-
-        rank_found = vocab.find_bpe_rank(left_token, right_token);
-
-        if (rank_found < 0) {
-            return;
-        }
-
-        llm_bigram_bpe bigram;
-
-        bigram.left  = left;
-        bigram.right = right;
-        bigram.text  = left_token + right_token;
-        bigram.size  = left_token.size() + right_token.size();
-        bigram.rank  = rank_found;
-
-        work_queue.push(bigram);
-    }
-
-    const llama_vocab & vocab;
-
-    std::vector symbols;
-    std::vector symbols_final;
-
-    llm_bigram_bpe::queue work_queue;
-};
-
-struct llm_tokenizer_wpm {
-    llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector & output) {
-        const auto & token_map = vocab.token_to_id;
-
-        // normalize and split by whitespace
-        std::vector words = preprocess(text);
-
-        // bos token prepended already
-
-        // find the longest tokens that form the words
-        for (const std::string &word : words) {
-            // skip empty words
-            if (word.size() == 0) {
-                continue;
-            }
-
-            // prepend phantom space
-            const std::string word1 = "\xe2\x96\x81" + word;
-            const int n = word1.size();
-
-            const size_t current_tokens = output.size();
-
-            // we're at the start of a new word
-            // move through character position in word
-            for (int i = 0; i < n; ++i) {
-                // loop through possible match length
-                bool match = false;
-                for (int j = n; j > i; j--) {
-                    auto it = token_map.find(word1.substr(i, j - i));
-                    if (it != token_map.end()) {
-                        output.push_back(it->second);
-                        match = true;
-                        i = j - 1;
-                        break;
-                    }
-                }
-
-                if (!match) { // discard all
-                    output.resize(current_tokens);
-                    break;  // and discard next tokens
-                }
-            }
-
-            // we didn't find any matches for this word
-            if (current_tokens == output.size()) {
-                output.push_back(vocab.special_unk_id);
-            }
-        }
-    }
-
-    std::vector preprocess(const std::string & text) {
-        const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
-        std::vector words(1, "");
-
-        for (const char32_t cpt : cpts_nfd) {
-            const auto flags = unicode_cpt_flags(cpt);
-
-            if (flags.is_whitespace) {
-                if (words.back().size()) {  // finish previous word if any
-                    words.emplace_back();
-                }
-                continue;
-            }
-
-            assert (!flags.is_separator);
-            if (cpt == 0 || cpt == 0xFFFD || flags.is_control) {
-                continue;
-            }
-
-            const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt));
-            if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) {
-                if (words.back().size()) {  // finish previous word if any
-                    words.emplace_back();
-                }
-                words.back() = s;       // single char word
-                words.emplace_back();   // start a new word
-            } else {
-                words.back() += s;  // append char to word
-            }
-        }
-
-        if (!words.back().size()) {
-            words.pop_back();
-        }
-
-        return words;
-    }
-
-    static bool is_chinese_char(uint32_t cpt) {
-        return
-            (cpt >= 0x04E00 && cpt <= 0x09FFF) ||
-            (cpt >= 0x03400 && cpt <= 0x04DBF) ||
-            (cpt >= 0x20000 && cpt <= 0x2A6DF) ||
-            (cpt >= 0x2A700 && cpt <= 0x2B73F) ||
-            (cpt >= 0x2B740 && cpt <= 0x2B81F) ||
-            (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920
-            (cpt >= 0x0F900 && cpt <= 0x0FAFF) ||
-            (cpt >= 0x2F800 && cpt <= 0x2FA1F);
-            //(cpt >= 0x3000  && cpt <= 0x303F)  ||
-            //(cpt >= 0xFF00  && cpt <= 0xFFEF);
-    }
-
-    const llama_vocab & vocab;
-};
-
-typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
-    FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN,
-    FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT
-} FRAGMENT_BUFFER_VARIANT_TYPE;
-
-struct fragment_buffer_variant {
-    fragment_buffer_variant(llama_vocab::id _token)
-    :
-        type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
-        token(_token),
-        raw_text(_dummy),
-        offset(0),
-        length(0) {}
-
-    fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
-    :
-        type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id) - 1),
-        raw_text(_raw_text),
-        offset(_offset),
-        length(_length){
-            GGML_ASSERT(_offset >= 0);
-            GGML_ASSERT(_length >= 1);
-            GGML_ASSERT(offset + length <= raw_text.length());
-        }
-
-    const FRAGMENT_BUFFER_VARIANT_TYPE type;
-    const llama_vocab::id token;
-    const std::string _dummy;
-    const std::string & raw_text;
-    const uint64_t offset;
-    const uint64_t length;
-};
-
-// #define PRETOKENIZERDEBUG
-
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) {
-    // for each special token
-    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & data = vocab.id_to_token[special_id];
-        const auto & special_token = data.text;
-
-        // for each text fragment
-        std::forward_list::iterator it = buffer.begin();
-        while (it != buffer.end()) {
-            auto & fragment = (*it);
-
-            // if a fragment is text ( not yet processed )
-            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                auto & raw_text = fragment.raw_text;
-
-                auto raw_text_base_offset = fragment.offset;
-                auto raw_text_base_length = fragment.length;
-
-                // loop over the text
-                while (true) {
-                    // find the first occurrence of a given special token in this fragment
-                    //  passing offset argument only limit the "search area" but match coordinates
-                    //  are still relative to the source full raw_text
-                    auto match = raw_text.find(special_token, raw_text_base_offset);
-
-                    // no occurrences found, stop processing this fragment for a given special token
-                    if (match == std::string::npos) break;
-
-                    // check if match is within bounds of offset <-> length
-                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
-
-#ifdef PRETOKENIZERDEBUG
-                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
-                    auto source = std::distance(buffer.begin(), it);
-
-                    // if match is further than base offset
-                    //  then we have some text to the left of it
-                    if (match > raw_text_base_offset) {
-                        // left
-                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
-                        int64_t left_reminder_length = match - raw_text_base_offset;
-
-                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
-                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
-                                left_reminder_length--;
-                            }
-                        }
-
-                        if (left_reminder_length > 0) {
-                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
-                            it++;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
-#endif
-                    }
-
-                    // special token
-                    buffer.emplace_after(it, special_id);
-                    it++;
-
-                    // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        int64_t right_reminder_offset = match + special_token.length();
-                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-
-                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
-                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
-                                right_reminder_offset++;
-                                right_reminder_length--;
-                            }
-                        }
-
-                        if (right_reminder_length > 0) {
-                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
-                            it++;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
-#endif
-
-                        if (source == 0) {
-                            buffer.erase_after(buffer.before_begin());
-                        } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
-                        }
-
-                        // repeat for the right side
-                        raw_text_base_offset = right_reminder_offset;
-                        raw_text_base_length = right_reminder_length;
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
-                    } else {
-                        if (source == 0) {
-                            buffer.erase_after(buffer.before_begin());
-                        } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
-                        }
-                        break;
-                    }
-                }
-            }
-            it++;
-        }
-    }
-}
-
-static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
-    std::vector output;
-    std::forward_list fragment_buffer;
-
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        if (parse_special) tokenizer_st_partition(vocab, fragment_buffer);
-    }
-
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('', add_special_tokens=False) returns []
-
-                bool is_prev_special = false;
-
-                if (add_special && vocab.special_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-                        if (vocab.add_space_prefix) {
-                            if (!output.size() || is_prev_special) {  // prefix with space if first token
-                                raw_text = " " + raw_text;
-                            }
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llm_tokenizer_spm tokenizer(vocab);
-                        llama_escape_whitespace(raw_text);
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
-
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.special_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                if (add_special && vocab.special_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llm_tokenizer_bpe tokenizer(vocab);
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.special_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_add_eos != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            {
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != -1);
-                    output.push_back(vocab.special_cls_id);
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llm_tokenizer_wpm tokenizer(vocab);
-                        tokenizer.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
-                    output.push_back(vocab.special_sep_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ASSERT(false);
-    }
-
-    return output;
-}
-
-//
-// grammar - internal
-//
-
-
-// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as
-// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
-std::pair, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8   partial_start) {
-    static const int      lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
-    const char          * pos      = src.c_str();
-    std::vector code_points;
-    // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
-    code_points.reserve(src.size() + 1);
-    uint32_t              value    = partial_start.value;
-    int                   n_remain = partial_start.n_remain;
-
-    // continue previous decode, if applicable
-    while (*pos != 0 && n_remain > 0) {
-        uint8_t next_byte = static_cast(*pos);
-        if ((next_byte >> 6) != 2) {
-            // invalid sequence, abort
-            code_points.push_back(0);
-            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 });
-        }
-        value = (value << 6) + (next_byte & 0x3F);
-        ++pos;
-        --n_remain;
-    }
-
-    if (partial_start.n_remain > 0 && n_remain == 0) {
-        code_points.push_back(value);
-    }
-
-    // decode any subsequent utf-8 sequences, which may end in an incomplete one
-    while (*pos != 0) {
-        uint8_t  first_byte = static_cast(*pos);
-        uint8_t  highbits   = first_byte >> 4;
-                 n_remain   = lookup[highbits] - 1;
-
-        if (n_remain < 0) {
-            // invalid sequence, abort
-            code_points.clear();
-            code_points.push_back(0);
-            return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain });
-        }
-
-        uint8_t  mask       = (1 << (7 - n_remain)) - 1;
-                 value      = first_byte & mask;
-        ++pos;
-        while (*pos != 0 && n_remain > 0) {
-            value = (value << 6) + (static_cast(*pos) & 0x3F);
-            ++pos;
-            --n_remain;
-        }
-        if (n_remain == 0) {
-            code_points.push_back(value);
-        }
-    }
-    code_points.push_back(0);
-
-    return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
-}
-
-// returns true iff pos points to the end of one of the definitions of a rule
-static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
-    switch (pos->type) {
-        case LLAMA_GRETYPE_END: return true;  // NOLINT
-        case LLAMA_GRETYPE_ALT: return true;  // NOLINT
-        default:                return false;
-    }
-}
-
-// returns true iff chr satisfies the char range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static std::pair llama_grammar_match_char(
-        const llama_grammar_element * pos,
-        const uint32_t                chr) {
-
-    bool found            = false;
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
-
-    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT
-
-    do {
-        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
-            // inclusive range, e.g. [a-z]
-            found = found || (pos->value <= chr && chr <= pos[1].value);
-            pos += 2;
-        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
-            // Any character matches "."
-            found = true;
-            pos += 1;
-        } else {
-            // exact char match, e.g. [a] or "a"
-            found = found || pos->value == chr;
-            pos += 1;
-        }
-    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
-
-    return std::make_pair(found == is_positive_char, pos);
-}
-
-// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char
-// range at pos (regular or inverse range)
-// asserts that pos is pointing to a char range element
-static bool llama_grammar_match_partial_char(
-        const llama_grammar_element * pos,
-        const llama_partial_utf8      partial_utf8) {
-
-    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY;
-    GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
-
-    uint32_t partial_value = partial_utf8.value;
-    int      n_remain      = partial_utf8.n_remain;
-
-    // invalid sequence or 7-bit char split across 2 bytes (overlong)
-    if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) {
-        return false;
-    }
-
-    // range of possible code points this partial UTF-8 sequence could complete to
-    uint32_t low  = partial_value << (n_remain * 6);
-    uint32_t high = low | ((1 << (n_remain * 6)) - 1);
-
-    if (low == 0) {
-        if (n_remain == 2) {
-            low = 1 << 11;
-        } else if (n_remain == 3) {
-            low = 1 << 16;
-        }
-    }
-
-    do {
-        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
-            // inclusive range, e.g. [a-z]
-            if (pos->value <= high && low <= pos[1].value) {
-                return is_positive_char;
-            }
-            pos += 2;
-        } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) {
-            // Any character matches "."
-            return true;
-        } else {
-            // exact char match, e.g. [a] or "a"
-            if (low <= pos->value && pos->value <= high) {
-                return is_positive_char;
-            }
-            pos += 1;
-        }
-    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
-
-    return !is_positive_char;
-}
-
-
-// transforms a grammar pushdown stack into N possible stacks, all ending
-// at a character range (terminal element)
-static void llama_grammar_advance_stack(
-        const std::vector>   & rules,
-        const std::vector        & stack,
-        std::vector> & new_stacks) {
-
-    if (stack.empty()) {
-        if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
-            new_stacks.emplace_back(stack);
-        }
-        return;
-    }
-
-    const llama_grammar_element * pos = stack.back();
-
-    switch (pos->type) {
-        case LLAMA_GRETYPE_RULE_REF: {
-            const size_t                  rule_id = static_cast(pos->value);
-            const llama_grammar_element * subpos  = rules[rule_id].data();
-            do {
-                // init new stack without the top (pos)
-                std::vector new_stack(stack.begin(), stack.end() - 1);
-                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
-                    // if this rule ref is followed by another element, add that to stack
-                    new_stack.push_back(pos + 1);
-                }
-                if (!llama_grammar_is_end_of_sequence(subpos)) {
-                    // if alternate is nonempty, add to stack
-                    new_stack.push_back(subpos);
-                }
-                llama_grammar_advance_stack(rules, new_stack, new_stacks);
-                while (!llama_grammar_is_end_of_sequence(subpos)) {
-                    // scan to end of alternate def
-                    subpos++;
-                }
-                if (subpos->type == LLAMA_GRETYPE_ALT) {
-                    // there's another alternate def of this rule to process
-                    subpos++;
-                } else {
-                    break;
-                }
-            } while (true);
-            break;
-        }
-        case LLAMA_GRETYPE_CHAR:
-        case LLAMA_GRETYPE_CHAR_NOT:
-        case LLAMA_GRETYPE_CHAR_ANY:
-            if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
-                // only add the stack if it's not a duplicate of one we already have
-                new_stacks.emplace_back(stack);
-            }
-            break;
-        default:
-            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
-            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
-            // those
-            GGML_ASSERT(false);
-    }
-}
-
-// takes a set of possible pushdown stacks on a grammar, which are required to
-// be positioned at a character range (see `llama_grammar_advance_stack`), and
-// produces the N possible stacks if the given char is accepted at those
-// positions
-void llama_grammar_accept(
-        const std::vector>         & rules,
-        const std::vector> & stacks,
-        const uint32_t                                                  chr,
-        std::vector>       & new_stacks) {
-
-    new_stacks.clear();
-
-    for (const auto & stack : stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        auto match = llama_grammar_match_char(stack.back(), chr);
-        if (match.first) {
-            const llama_grammar_element * pos = match.second;
-
-            // update top of stack to next element, if any
-            std::vector new_stack(stack.begin(), stack.end() - 1);
-            if (!llama_grammar_is_end_of_sequence(pos)) {
-                new_stack.push_back(pos);
-            }
-            llama_grammar_advance_stack(rules, new_stack, new_stacks);
-        }
-    }
-}
-
-static std::vector llama_grammar_reject_candidates(
-        const std::vector>         & rules,
-        const std::vector> & stacks,
-        const std::vector                    & candidates);
-
-static std::vector llama_grammar_reject_candidates_for_stack(
-        const std::vector> & rules,
-        const std::vector      & stack,
-        const std::vector            & candidates) {
-
-    std::vector rejects;
-    rejects.reserve(candidates.size());
-
-    if (stack.empty()) {
-        for (const auto & tok : candidates) {
-            if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) {
-                rejects.push_back(tok);
-            }
-        }
-        return rejects;
-    }
-
-    const llama_grammar_element * stack_pos = stack.back();
-
-    std::vector next_candidates;
-    next_candidates.reserve(candidates.size());
-
-    for (const auto & tok : candidates) {
-        if (*tok.code_points == 0) {
-            // reached end of full codepoints in token, reject iff it ended in a partial sequence
-            // that cannot satisfy this position in grammar
-            if (tok.partial_utf8.n_remain != 0 &&
-                    !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) {
-                rejects.push_back(tok);
-            }
-        } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
-            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
-        } else {
-            rejects.push_back(tok);
-        }
-    }
-
-    const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
-
-    // update top of stack to next element, if any
-    std::vector stack_after(stack.begin(), stack.end() - 1);
-    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
-        stack_after.push_back(stack_pos_after);
-    }
-    std::vector> next_stacks;
-    llama_grammar_advance_stack(rules, stack_after, next_stacks);
-
-    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
-    for (const auto & tok : next_rejects) {
-        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
-    }
-
-    return rejects;
-}
-
-static std::vector llama_grammar_reject_candidates(
-        const std::vector>         & rules,
-        const std::vector> & stacks,
-        const std::vector                    & candidates) {
-    GGML_ASSERT(!stacks.empty()); // REVIEW
-
-    if (candidates.empty()) {
-        return std::vector();
-    }
-
-    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
-
-    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
-        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
-    }
-    return rejects;
-}
-
-static bool llama_grammar_detect_left_recursion(
-        const std::vector> & rules,
-        size_t                                                  rule_index,
-        std::vector                                     * rules_visited,
-        std::vector                                     * rules_in_progress,
-        std::vector                                     * rules_may_be_empty) {
-    if ((*rules_in_progress)[rule_index]) {
-        return true;
-    }
-
-    (*rules_in_progress)[rule_index] = true;
-
-    const std::vector & rule = rules[rule_index];
-
-    // First check if the rule might produce the empty string. This could be done combined with the second
-    // step but it's more readable as two steps.
-    bool at_rule_start = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            if (at_rule_start) {
-                (*rules_may_be_empty)[rule_index] = true;
-                break;
-            }
-            at_rule_start = true;
-        } else {
-            at_rule_start = false;
-        }
-    }
-
-    // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
-    // be empty)
-    bool recurse_into_nonterminal = true;
-    for (size_t i = 0; i < rule.size(); i++) {
-        if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
-            if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
-                return true;
-            }
-            if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
-                recurse_into_nonterminal = false;
-            }
-        } else if (llama_grammar_is_end_of_sequence(&rule[i])) {
-            recurse_into_nonterminal = true;
-        } else {
-            recurse_into_nonterminal = false;
-        }
-    }
-
-    (*rules_in_progress)[rule_index] = false;
-    (*rules_visited)[rule_index] = true;
-    return false;
-}
-
-//
-// grammar - external
-//
-
-struct llama_grammar * llama_grammar_init(
-            const llama_grammar_element ** rules,
-                                 size_t    n_rules,
-                                 size_t    start_rule_index) {
-    const llama_grammar_element * pos;
-
-    // copy rule definitions into vectors
-    std::vector> vec_rules(n_rules);
-    for (size_t i = 0; i < n_rules; i++) {
-        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
-            vec_rules[i].push_back(*pos);
-        }
-        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
-    }
-
-    // Check for left recursion
-    std::vector rules_visited(n_rules);
-    std::vector rules_in_progress(n_rules);
-    std::vector rules_may_be_empty(n_rules);
-    for (size_t i = 0; i < n_rules; i++) {
-        if (rules_visited[i]) {
-            continue;
-        }
-        if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
-            throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
-        }
-    }
-
-    // loop over alternates of start rule to build initial stacks
-    std::vector> stacks;
-    pos = vec_rules[start_rule_index].data();
-    do {
-        std::vector stack;
-        if (!llama_grammar_is_end_of_sequence(pos)) {
-            // if alternate is nonempty, add to stack
-            stack.push_back(pos);
-        }
-        llama_grammar_advance_stack(vec_rules, stack, stacks);
-        while (!llama_grammar_is_end_of_sequence(pos)) {
-            // scan to end of alternate def
-            pos++;
-        }
-        if (pos->type == LLAMA_GRETYPE_ALT) {
-            // there's another alternate def of this rule to process
-            pos++;
-        } else {
-            break;
-        }
-    } while (true);
-
-    // Important: vec_rules has to be moved here, not copied, because stacks contains
-    // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
-    // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
-}
-
-void llama_grammar_free(struct llama_grammar * grammar) {
-    delete grammar;
-}
-
-struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
-
-    // redirect elements in stacks to point to new rules
-    for (size_t is = 0; is < result->stacks.size(); is++) {
-        for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
-                         result->stacks[is][ie]  =  &result->rules[ir0][ir1];
-                    }
-                }
-            }
-        }
-    }
-
-    return result;
-}
-
-//
-// sampling
-//
-
-void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
-    if (seed == LLAMA_DEFAULT_SEED) {
-        seed = time(NULL);
-    }
-    ctx->rng.seed(seed);
-}
-
-void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
-    GGML_ASSERT(candidates->size > 0);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Sort the logits in descending order
-    if (!candidates->sorted) {
-        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        });
-        candidates->sorted = true;
-    }
-
-    float max_l = candidates->data[0].logit;
-    float cum_sum = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float p = expf(candidates->data[i].logit - max_l);
-        candidates->data[i].p = p;
-        cum_sum += p;
-    }
-    for (size_t i = 0; i < candidates->size; ++i) {
-        candidates->data[i].p /= cum_sum;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
-    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
-    // if (k >= (int32_t)candidates->size) {
-    //     return;
-    // }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (k <= 0) {
-        k = candidates->size;
-    }
-
-    k = std::max(k, (int) min_keep);
-    k = std::min(k, (int) candidates->size);
-
-    // Sort scores in descending order
-    if (!candidates->sorted) {
-        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
-            return a.logit > b.logit;
-        };
-        if (k <= 128) {
-            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
-        } else {
-            constexpr int   nbuckets     = 128;
-            constexpr float bucket_low   = -10.0f;
-            constexpr float bucket_high  =  10.0f;
-            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
-            constexpr float bucker_inter = -bucket_low * bucket_scale;
-
-            std::vector bucket_idx(candidates->size);
-            std::vector histo(nbuckets, 0);
-
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                const float val = candidates->data[i].logit;
-                int ib = int(bucket_scale * val + bucker_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
-                bucket_idx[i] = ib;
-                ++histo[ib];
-            }
-            int nhave = 0;
-            int ib = nbuckets - 1;
-            for ( ; ib >= 0; --ib) {
-                nhave += histo[ib];
-                if (nhave >= k) break;
-            }
-            std::vector tmp_tokens(nhave);
-            auto ptr = tmp_tokens.data();
-            std::vector bucket_ptrs;
-            bucket_ptrs.reserve(nbuckets - ib);
-            for (int j = nbuckets - 1; j >= ib; --j) {
-                bucket_ptrs.push_back(ptr);
-                ptr += histo[j];
-            }
-            for (int i = 0; i < (int)candidates->size; ++i) {
-                int j = bucket_idx[i];
-                if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
-                }
-            }
-
-            ptr = tmp_tokens.data();
-            int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
-                std::sort(ptr, ptr + histo[j], comp);
-                ptr += histo[j];
-                ndone += histo[j];
-            }
-            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
-
-            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
-
-        }
-        candidates->sorted = true;
-    }
-    candidates->size = k;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p >= 1.0f) {
-        return;
-    }
-
-    llama_sample_softmax(ctx, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Compute the cumulative probabilities
-    float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
-
-    for (size_t i = 0; i < candidates->size; ++i) {
-        cum_sum += candidates->data[i].p;
-
-        // Check if the running sum is at least p or if we have kept at least min_keep tokens
-        // we set the last index to i+1 to indicate that the current iterate should be included in the set
-        if (cum_sum >= p && i + 1 >= min_keep) {
-            last_idx = i + 1;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the top-p tokens
-    candidates->size = last_idx;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    if (p <= 0.0f || !candidates->size) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    bool min_p_applied = false;
-
-    // if the candidates aren't sorted, try the unsorted implementation first
-    if (!candidates->sorted) {
-        std::vector filtered_tokens;
-
-        float max_logit = -FLT_MAX;
-        for (size_t i = 0; i < candidates->size; ++i) {
-            max_logit = std::max(max_logit, candidates->data[i].logit);
-        }
-        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
-
-        for (size_t i = 0; i < candidates->size; ++i) {
-            if (candidates->data[i].logit >= min_logit) {
-                filtered_tokens.push_back(candidates->data[i]);
-            }
-        }
-
-        // if we have enough values the operation was a success
-        if (filtered_tokens.size() >= min_keep) {
-            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
-            candidates->size = filtered_tokens.size();
-            min_p_applied = true;
-        }
-    }
-
-    // if the candidates are sorted or the unsorted implementation failed, use this implementation
-    if (!min_p_applied) {
-        // Sort the logits in descending order
-        if (!candidates->sorted) {
-            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-                return a.logit > b.logit;
-            });
-            candidates->sorted = true;
-        }
-
-        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
-        size_t i = 1; // first token always matches
-
-        for (; i < candidates->size; ++i) {
-            if (candidates->data[i].logit < min_logit && i >= min_keep) {
-                break; // prob too small
-            }
-        }
-
-        // Resize the output vector to keep only the matching tokens
-        candidates->size = i;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
-    if (z >= 1.0f || candidates->size <= 2) {
-        return;
-    }
-
-    llama_sample_softmax(nullptr, candidates);
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Compute the first and second derivatives
-    std::vector first_derivatives(candidates->size - 1);
-    std::vector second_derivatives(candidates->size - 2);
-
-    for (size_t i = 0; i < first_derivatives.size(); ++i) {
-        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
-    }
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
-    }
-
-    // Calculate absolute value of second derivatives
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        second_derivatives[i] = std::abs(second_derivatives[i]);
-    }
-
-    // Normalize the second derivatives
-    {
-        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
-
-        if (second_derivatives_sum > 1e-6f) {
-            for (float & value : second_derivatives) {
-                value /= second_derivatives_sum;
-            }
-        } else {
-            for (float & value : second_derivatives) {
-                value = 1.0f / second_derivatives.size();
-            }
-        }
-    }
-
-    float cum_sum = 0.0f;
-    size_t last_idx = candidates->size;
-    for (size_t i = 0; i < second_derivatives.size(); ++i) {
-        cum_sum += second_derivatives[i];
-
-        // Check if the running sum is greater than z or if we have kept at least min_keep tokens
-        if (cum_sum > z && i >= min_keep) {
-            last_idx = i;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the tokens above the tail location
-    candidates->size = last_idx;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
-    // Reference implementation:
-    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
-    if (p >= 1.0f) {
-        return;
-    }
-
-    // Compute the softmax of logits and calculate entropy
-    llama_sample_softmax(nullptr, candidates);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
-    }
-
-    // Compute the absolute difference between negative log probability and entropy for each candidate
-    std::vector shifted_scores;
-    for (size_t i = 0; i < candidates->size; ++i) {
-        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
-        shifted_scores.push_back(shifted_score);
-    }
-
-    // Sort tokens based on the shifted_scores and their corresponding indices
-    std::vector indices(candidates->size);
-    std::iota(indices.begin(), indices.end(), 0);
-
-    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
-        return shifted_scores[a] < shifted_scores[b];
-    });
-
-    // Compute the cumulative probabilities
-    float cum_sum = 0.0f;
-    size_t last_idx = indices.size();
-
-    for (size_t i = 0; i < indices.size(); ++i) {
-        size_t idx = indices[i];
-        cum_sum += candidates->data[idx].p;
-
-        // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
-        if (cum_sum > p && i >= min_keep - 1) {
-            last_idx = i + 1;
-            break;
-        }
-    }
-
-    // Resize the output vector to keep only the locally typical tokens
-    std::vector new_candidates;
-    for (size_t i = 0; i < last_idx; ++i) {
-        size_t idx = indices[i];
-        new_candidates.push_back(candidates->data[idx]);
-    }
-
-    // Replace the data in candidates with the new_candidates data
-    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
-    candidates->size = new_candidates.size();
-    candidates->sorted = false;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // no need to do anything if there is only one (or zero) candidates
-    if(candidates_p->size <= 1) {
-        return;
-    }
-
-    // Calculate maximum possible entropy
-    float max_entropy = -logf(1.0f / candidates_p->size);
-
-    llama_sample_softmax(nullptr, candidates_p);
-
-    // Calculate entropy of the softmax probabilities
-    float entropy = 0.0f;
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        float prob = candidates_p->data[i].p;
-        if (prob > 0.0f) { // Ensure no log(0)
-            entropy -= prob * logf(prob);
-        }
-    }
-
-    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates_p->size != 1 above)
-    float normalized_entropy = entropy / max_entropy;
-
-    // Map the normalized entropy to the desired temperature range using the power function
-    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
-
-#ifdef DEBUG
-    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
-    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
-    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
-    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
-    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
-    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
-#endif
-
-    // Apply the dynamically calculated temperature scaling
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        candidates_p->data[i].logit /= dyn_temp;
-    }
-
-    // Re-compute softmax probabilities after scaling logits with dynamic temperature
-    double max_l_double = candidates_p->data[0].logit;
-    double cum_sum_double = 0.0;
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        double p = exp(candidates_p->data[i].logit - max_l_double);
-        candidates_p->data[i].p = p; // Store the scaled probability
-        cum_sum_double += p;
-    }
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        candidates_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
-    }
-
-#ifdef DEBUG
-    // Print the updated top 25 probabilities after temperature scaling
-    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
-    for (size_t i = 0; i < 25 && i < candidates_p->size; ++i) {
-        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates_p->data[i].p * 100.0f);
-    }
-#endif
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    for (size_t i = 0; i < candidates_p->size; ++i) {
-        candidates_p->data[i].logit /= temp;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_repetition_penalties(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-               const llama_token * last_tokens,
-                          size_t   penalty_last_n,
-                           float   penalty_repeat,
-                           float   penalty_freq,
-                           float   penalty_present) {
-    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-        return;
-    }
-
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Create a frequency map to count occurrences of each token in last_tokens
-    std::unordered_map token_count;
-    for (size_t i = 0; i < penalty_last_n; ++i) {
-        token_count[last_tokens[i]]++;
-    }
-
-    // Apply frequency and presence penalties to the candidates
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const auto token_iter = token_count.find(candidates->data[i].id);
-        if (token_iter == token_count.end()) {
-            continue;
-        }
-
-        const int count = token_iter->second;
-
-        // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
-        // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
-        if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= penalty_repeat;
-        } else {
-            candidates->data[i].logit /= penalty_repeat;
-        }
-
-        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
-    }
-
-    candidates->sorted = false;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-}
-
-void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
-    GGML_ASSERT(ctx);
-    int64_t t_start_sample_us = ggml_time_us();
-
-    bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
-        if (stack.empty()) {
-            allow_eog = true;
-            break;
-        }
-    }
-
-    std::vector, llama_partial_utf8>> candidates_decoded;
-    candidates_decoded.reserve(candidates->size);
-
-    std::vector candidates_grammar;
-    candidates_grammar.reserve(candidates->size);
-
-    for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id      = candidates->data[i].id;
-        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
-
-        if (llama_token_is_eog(&ctx->model, id)) {
-            if (!allow_eog) {
-                candidates->data[i].logit = -INFINITY;
-            }
-        } else if (piece.empty() || piece[0] == 0) {
-            candidates->data[i].logit = -INFINITY;
-        } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
-            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
-        }
-    }
-
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
-    for (const auto & reject : rejects) {
-        candidates->data[reject.index].logit = -INFINITY;
-    }
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
-static void llama_log_softmax(float * array, size_t size) {
-    float max_l = *std::max_element(array, array + size);
-    float sum = 0.f;
-    for (size_t i = 0; i < size; ++i) {
-        float p = expf(array[i] - max_l);
-        sum += p;
-        array[i] = p;
-    }
-
-    for (size_t i = 0; i < size; ++i) {
-        array[i] = logf(array[i] / sum);
-    }
-}
-
-void llama_sample_apply_guidance(
-          struct llama_context * ctx,
-                         float * logits,
-                         float * logits_guidance,
-                         float   scale) {
-    GGML_ASSERT(ctx);
-
-    const auto t_start_sample_us = ggml_time_us();
-    const auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-
-    llama_log_softmax(logits, n_vocab);
-    llama_log_softmax(logits_guidance, n_vocab);
-
-    for (int i = 0; i < n_vocab; ++i) {
-              auto & l = logits[i];
-        const auto & g = logits_guidance[i];
-
-        l = scale * (l - g) + g;
-    }
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
-llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
-    GGML_ASSERT(ctx);
-
-    auto N = float(llama_n_vocab(llama_get_model(ctx)));
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax(nullptr, candidates);
-
-    // Estimate s_hat using the most probable m tokens
-    float s_hat = 0.0;
-    float sum_ti_bi = 0.0;
-    float sum_ti_sq = 0.0;
-    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
-        float t_i = logf(float(i + 2) / float(i + 1));
-        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
-        sum_ti_bi += t_i * b_i;
-        sum_ti_sq += t_i * t_i;
-    }
-    s_hat = sum_ti_bi / sum_ti_sq;
-
-    // Compute k from the estimated s_hat and target surprise value
-    float epsilon_hat = s_hat - 1;
-    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
-
-    // Sample the next word X using top-k sampling
-    llama_sample_top_k(nullptr, candidates, int(k), 1);
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    llama_token X = llama_sample_token(ctx, candidates);
-    t_start_sample_us = ggml_time_us();
-
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
-
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    return X;
-}
-
-llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
-    int64_t t_start_sample_us;
-    t_start_sample_us = ggml_time_us();
-
-    llama_sample_softmax(ctx, candidates);
-
-    // Truncate the words with surprise values greater than mu
-    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return -log2f(candidate.p) > *mu;
-    }));
-
-    if (candidates->size == 0) {
-        candidates->size = 1;
-    }
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-
-    // Normalize the probabilities of the remaining words
-    llama_sample_softmax(ctx, candidates);
-
-    // Sample the next word X from the remaining words
-    llama_token X = llama_sample_token(ctx, candidates);
-    t_start_sample_us = ggml_time_us();
-
-    // Compute error as the difference between observed surprise and target surprise value
-    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
-        return candidate.id == X;
-    }));
-    float observed_surprise = -log2f(candidates->data[X_idx].p);
-    float e = observed_surprise - tau;
-
-    // Update mu using the learning rate and error
-    *mu = *mu - eta * e;
-
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    }
-    return X;
-}
-
-llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    // Find max element
-    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
-        return a.logit < b.logit;
-    });
-
-    llama_token result = max_iter->id;
-    if (ctx) {
-        ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-        ctx->n_sample++;
-    }
-    return result;
-}
-
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
-    GGML_ASSERT(ctx);
-
-    const int64_t t_start_sample_us = ggml_time_us();
-    llama_sample_softmax(nullptr, candidates);
-
-    std::vector probs;
-    probs.reserve(candidates->size);
-    for (size_t i = 0; i < candidates->size; ++i) {
-        probs.push_back(candidates->data[i].p);
-    }
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    llama_token result = candidates->data[idx].id;
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-    ctx->n_sample++;
-    return result;
-}
-
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
-    return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
-}
-
-void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (llama_token_is_eog(&ctx->model, token)) {
-        for (const auto & stack : grammar->stacks) {
-            if (stack.empty()) {
-                return;
-            }
-        }
-        GGML_ASSERT(false);
-    }
-
-    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
-
-    // Note terminating 0 in decoded string
-    const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
-    const auto & code_points = decoded.first;
-    std::vector> tmp_new_stacks;
-    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
-    }
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-}
-
 //
 // quantization
 //
@@ -14797,7 +15259,7 @@ static void llama_tensor_dequantize_internal(
         } else if (ggml_is_quantized(tensor->type)) {
             qtype.to_float(tensor->data, f32_output, nelements);
         } else {
-            GGML_ASSERT(false); // unreachable
+            GGML_ABORT("fatal error"); // unreachable
         }
         return;
     }
@@ -14849,8 +15311,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     const llm_arch arch = qs.model.arch;
     const auto       tn = LLM_TN(arch);
 
-    auto use_more_bits = [](int i_layer, int num_layers) -> bool {
-        return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2;
+    auto use_more_bits = [](int i_layer, int n_layers) -> bool {
+        return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2;
     };
     const int n_expert = std::max(1, (int)qs.model.hparams.n_expert);
     auto layer_info = [n_expert] (int i_layer, int n_layer, const char * name) {
@@ -14902,6 +15364,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
             else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
                 new_type = GGML_TYPE_IQ3_S;
             }
+            else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 ||
+                     new_type == GGML_TYPE_Q4_0_8_8) {
+                new_type = GGML_TYPE_Q4_0;
+            }
         }
     } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
                ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M    || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
@@ -15085,10 +15551,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
-        new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S || new_type == GGML_TYPE_IQ3_S ||
+    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
+        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
+        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
+        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
         new_type == GGML_TYPE_IQ1_M) {
         int nx = tensor->ne[0];
         int ny = tensor->ne[1];
@@ -15214,6 +15680,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_IQ4_XS:  default_type = GGML_TYPE_IQ4_XS;  break;
         case LLAMA_FTYPE_MOSTLY_IQ3_S:   default_type = GGML_TYPE_IQ3_S;   break;
         case LLAMA_FTYPE_MOSTLY_IQ3_M:   default_type = GGML_TYPE_IQ3_S;   break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = GGML_TYPE_Q4_0_4_4; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = GGML_TYPE_Q4_0_4_8; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = GGML_TYPE_Q4_0_8_8; break;
 
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
@@ -15255,6 +15724,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         if (imatrix_data) {
             LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
             qs.has_imatrix = true;
+            // check imatrix for nans or infs
+            for (const auto & kv : *imatrix_data) {
+                for (float f : kv.second) {
+                    if (!std::isfinite(f)) {
+                        throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
+                    }
+                }
+            }
         }
     }
 
@@ -15263,8 +15740,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     // copy the KV pairs from the input file
     gguf_set_kv     (ctx_out, ml.meta);
-    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION);
-    gguf_set_val_u32(ctx_out, "general.file_type", ftype);
+    gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
+    gguf_set_val_u32(ctx_out, "general.file_type", ftype); // TODO: use LLM_KV
+
     // Remove split metadata
     gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_NO).c_str());
     gguf_remove_key(ctx_out, ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str());
@@ -15306,10 +15784,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     // sanity checks
     //
-    //  - qs.n_attention_wv == 0                     for Mamba       models
-    //  - qs.n_attention_wv == model.hparams.n_layer for Transformer models
+    //  - qs.n_attention_wv == 0                         for Mamba           models
+    //  - qs.n_attention_wv == model.hparams.n_layer     for Transformer     models
+    //  - qs.n_attention_wv == 3 * model.hparams.n_layer for Encoder-Decoder models
     //
-    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
+    GGML_ASSERT((qs.n_attention_wv == 0 || qs.n_attention_wv == (int)model.hparams.n_layer || qs.n_attention_wv == 3 * (int)model.hparams.n_layer) && "n_attention_wv is unexpected");
 
     size_t total_size_org = 0;
     size_t total_size_new = 0;
@@ -15434,6 +15913,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= name.find("ssm_x.weight")      == std::string::npos;
         quantize &= name.find("ssm_dt.weight")     == std::string::npos;
 
+        // do not quantize relative position bias (T5)
+        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
+
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;
@@ -15512,6 +15994,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 f32_data = (float *) f32_conv_buf.data();
             }
 
+            int chunk_size_multiplier = 1;
+            if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8 || new_type == GGML_TYPE_Q4_0_8_8) {
+                if ((new_type == GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = GGML_TYPE_Q4_0;
+                else if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
+                if (new_type == GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8;
+                else if (new_type == GGML_TYPE_Q4_0_4_4 || new_type == GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4;
+            }
+
             LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type));
             fflush(stdout);
 
@@ -15524,7 +16014,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             const int64_t nrows = tensor->ne[1];
 
             static const int64_t min_chunk_size = 32 * 512;
-            const int64_t chunk_size = n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row);
+            const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) *
+                                       chunk_size_multiplier;
 
             const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
             const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
@@ -15566,6 +16057,3134 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     }
 }
 
+static void llama_lora_adapter_init_internal(struct llama_model * model, const char * path_lora, struct llama_lora_adapter & adapter) {
+    LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
+
+    ggml_context * ctx = nullptr;
+    struct gguf_init_params meta_gguf_params = {
+        /* .no_alloc = */ true,
+        /* .ctx      = */ &ctx,
+    };
+    struct gguf_context * ctx_gguf = gguf_init_from_file(path_lora, meta_gguf_params);
+    if (!ctx_gguf) {
+        throw std::runtime_error("failed to load lora adapter file from " + std::string(path_lora));
+    }
+
+    // check metadata
+    {
+        auto get_kv_str = [&](const std::string & key) -> std::string {
+            int id = gguf_find_key(ctx_gguf, key.c_str());
+            return id < 0 ? "" : std::string(gguf_get_val_str(ctx_gguf, id));
+        };
+        auto get_kv_f32 = [&](const std::string & key) -> float {
+            int id = gguf_find_key(ctx_gguf, key.c_str());
+            return id < 0 ? 0.0f : gguf_get_val_f32(ctx_gguf, id);
+        };
+        LLM_KV llm_kv = LLM_KV(LLM_ARCH_UNKNOWN);
+
+        auto general_type = get_kv_str(llm_kv(LLM_KV_GENERAL_TYPE));
+        if (general_type != "adapter") {
+            gguf_free(ctx_gguf);
+            throw std::runtime_error("expect general.type to be 'adapter', but got: " + general_type);
+        }
+
+        auto general_arch_str = get_kv_str(llm_kv(LLM_KV_GENERAL_ARCHITECTURE));
+        auto general_arch = llm_arch_from_string(general_arch_str);
+        if (general_arch != model->arch) {
+            gguf_free(ctx_gguf);
+            throw std::runtime_error("model arch and LoRA arch mismatch");
+        }
+
+        auto adapter_type = get_kv_str(llm_kv(LLM_KV_ADAPTER_TYPE));
+        if (adapter_type != "lora") {
+            gguf_free(ctx_gguf);
+            throw std::runtime_error("expect adapter.type to be 'lora', but got: " + adapter_type);
+        }
+
+        adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
+    }
+
+    int n_tensors = gguf_get_n_tensors(ctx_gguf);
+
+    // contexts for each buffer type
+    std::map ctx_map;
+    auto get_ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
+        auto it = ctx_map.find(buft);
+        if (it == ctx_map.end()) {
+            // add a new context
+            struct ggml_init_params params = {
+                /*.mem_size   =*/ n_tensors*ggml_tensor_overhead(),
+                /*.mem_buffer =*/ NULL,
+                /*.no_alloc   =*/ true,
+            };
+            ggml_context * buft_ctx = ggml_init(params);
+            ctx_map[buft] = buft_ctx;
+            return buft_ctx;
+        };
+        return it->second;
+    };
+
+    // bundle lora_a and lora_b into pairs
+    std::map ab_map;
+    auto str_endswith = [](const std::string & str, const std::string & suffix) {
+        return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
+    };
+    for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
+        std::string name(cur->name);
+        if (str_endswith(name, ".lora_a")) {
+            replace_all(name, ".lora_a", "");
+            if (ab_map.find(name) == ab_map.end()) {
+                ab_map[name] = llama_lora_weight(cur, nullptr);
+            } else {
+                ab_map[name].a = cur;
+            }
+        } else if (str_endswith(name, ".lora_b")) {
+            replace_all(name, ".lora_b", "");
+            if (ab_map.find(name) == ab_map.end()) {
+                ab_map[name] = llama_lora_weight(nullptr, cur);
+            } else {
+                ab_map[name].b = cur;
+            }
+        } else {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix");
+        }
+    }
+
+    // add tensors
+    for (auto & it : ab_map) {
+        const std::string & name = it.first;
+        llama_lora_weight & w = it.second;
+
+        if (!w.a || !w.b) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component");
+        }
+
+        // device buft and device ctx
+        auto * model_tensor = llama_get_model_tensor(model, name.c_str());
+        if (!model_tensor) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model");
+        }
+        struct ggml_context * dev_ctx = get_ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        // validate tensor shape
+        if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("tensor '" + name + "' has incorrect shape");
+        }
+        if (w.a->ne[1] != w.b->ne[0]) {
+            gguf_free(ctx_gguf);
+            ggml_free(ctx);
+            throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)");
+        }
+        // save tensor to adapter
+        struct ggml_tensor * tensor_a = ggml_dup_tensor(dev_ctx, w.a);
+        struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
+        ggml_set_name(tensor_a, w.a->name);
+        ggml_set_name(tensor_b, w.b->name);
+        adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b);
+    }
+
+    // allocate tensors / buffers and zero
+    {
+        adapter.ctxs.reserve(ctx_map.size());
+        adapter.bufs.reserve(ctx_map.size());
+        for (auto it : ctx_map) {
+            ggml_backend_buffer_type_t buft = it.first;
+            ggml_context * ctx_dev = it.second;
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_dev, buft);
+            if (!buf) {
+                gguf_free(ctx_gguf);
+                ggml_free(ctx);
+                throw std::runtime_error("failed to allocate buffer for lora adapter\n");
+            }
+            LLAMA_LOG_INFO("%s: %10s LoRA buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+            adapter.ctxs.push_back(ctx_dev);
+            adapter.bufs.push_back(buf);
+        }
+    }
+
+    // set tensor data
+    {
+        llama_file gguf_file(path_lora, "rb");
+        std::vector read_buf;
+        auto set_tensor = [&](struct ggml_tensor * orig, struct ggml_tensor * dev) {
+            size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, gguf_find_tensor(ctx_gguf, orig->name));
+            size_t size = ggml_nbytes(orig);
+            read_buf.resize(size);
+            gguf_file.seek(offs, SEEK_SET);
+            gguf_file.read_raw(read_buf.data(), size);
+            ggml_backend_tensor_set(dev, read_buf.data(), 0, size);
+        };
+        for (auto & it : adapter.ab_map) {
+            auto orig = ab_map[it.first];
+            auto dev  = it.second;
+            set_tensor(orig.a, dev.a);
+            set_tensor(orig.b, dev.b);
+        }
+    }
+
+    LLAMA_LOG_INFO("%s: loaded %ld tensors from lora file\n", __func__, adapter.ab_map.size()*2);
+
+    // free ctx for reading gguf
+    gguf_free(ctx_gguf);
+    ggml_free(ctx);
+}
+
+int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale) {
+    if (ctx->cparams.flash_attn) {
+        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
+        return -1;
+    }
+    ctx->lora_adapters[adapter] = scale;
+    return 0;
+}
+
+int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter) {
+    auto pos = ctx->lora_adapters.find(adapter);
+    if (pos != ctx->lora_adapters.end()) {
+        ctx->lora_adapters.erase(pos);
+        return 0;
+    }
+    return -1;
+}
+
+void llama_lora_adapter_clear(struct llama_context * ctx) {
+    ctx->lora_adapters.clear();
+}
+
+void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
+    delete adapter;
+}
+
+//
+// interface implementation
+//
+struct llama_model_params llama_model_default_params() {
+    struct llama_model_params result = {
+        /*.n_gpu_layers                =*/ 0,
+        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
+        /*.main_gpu                    =*/ 0,
+        /*.tensor_split                =*/ nullptr,
+        /*.rpc_servers                 =*/ nullptr,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+        /*.vocab_only                  =*/ false,
+        /*.use_mmap                    =*/ true,
+        /*.use_mlock                   =*/ false,
+        /*.check_tensors               =*/ false,
+    };
+
+#ifdef GGML_USE_METAL
+    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
+    result.n_gpu_layers = 999;
+#endif
+
+    return result;
+}
+
+struct llama_context_params llama_context_default_params() {
+    struct llama_context_params result = {
+        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
+        /*.n_ctx                       =*/ 512,
+        /*.n_batch                     =*/ 2048,
+        /*.n_ubatch                    =*/ 512,
+        /*.n_seq_max                   =*/ 1,
+        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
+        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
+        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
+        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
+        /*.rope_freq_base              =*/ 0.0f,
+        /*.rope_freq_scale             =*/ 0.0f,
+        /*.yarn_ext_factor             =*/ -1.0f,
+        /*.yarn_attn_factor            =*/ 1.0f,
+        /*.yarn_beta_fast              =*/ 32.0f,
+        /*.yarn_beta_slow              =*/ 1.0f,
+        /*.yarn_orig_ctx               =*/ 0,
+        /*.defrag_thold                =*/ -1.0f,
+        /*.cb_eval                     =*/ nullptr,
+        /*.cb_eval_user_data           =*/ nullptr,
+        /*.type_k                      =*/ GGML_TYPE_F16,
+        /*.type_v                      =*/ GGML_TYPE_F16,
+        /*.logits_all                  =*/ false,
+        /*.embeddings                  =*/ false,
+        /*.offload_kqv                 =*/ true,
+        /*.flash_attn                  =*/ false,
+        /*.abort_callback              =*/ nullptr,
+        /*.abort_callback_data         =*/ nullptr,
+    };
+
+    return result;
+}
+
+struct llama_model_quantize_params llama_model_quantize_default_params() {
+    struct llama_model_quantize_params result = {
+        /*.nthread                     =*/ 0,
+        /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
+        /*.output_tensor_type          =*/ GGML_TYPE_COUNT,
+        /*.token_embedding_type        =*/ GGML_TYPE_COUNT,
+        /*.allow_requantize            =*/ false,
+        /*.quantize_output_tensor      =*/ true,
+        /*.only_copy                   =*/ false,
+        /*.pure                        =*/ false,
+        /*.keep_split                  =*/ false,
+        /*.imatrix                     =*/ nullptr,
+        /*.kv_overrides                =*/ nullptr,
+    };
+
+    return result;
+}
+
+size_t llama_max_devices(void) {
+#if defined(GGML_USE_RPC)
+    return GGML_RPC_MAX_SERVERS;
+#elif defined(GGML_USE_METAL)
+    return 1;
+#elif defined(GGML_USE_CUDA)
+    return GGML_CUDA_MAX_DEVICES;
+#elif defined(GGML_USE_SYCL)
+    return GGML_SYCL_MAX_DEVICES;
+#elif defined(GGML_USE_VULKAN)
+    return GGML_VK_MAX_DEVICES;
+#elif defined(GGML_USE_CANN)
+    return GGML_CANN_MAX_DEVICES;
+#else
+    return 1;
+#endif
+}
+
+bool llama_supports_mmap(void) {
+    return llama_mmap::SUPPORTED;
+}
+
+bool llama_supports_mlock(void) {
+    return llama_mlock::SUPPORTED;
+}
+
+bool llama_supports_gpu_offload(void) {
+#if defined(GGML_USE_CUDA) || defined(GGML_USE_METAL)   || defined(GGML_USE_VULKAN) || \
+    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
+    // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
+    return true;
+#else
+    return false;
+#endif
+}
+
+void llama_backend_init(void) {
+    ggml_time_init();
+
+    // needed to initialize f16 tables
+    {
+        struct ggml_init_params params = { 0, NULL, false };
+        struct ggml_context * ctx = ggml_init(params);
+        ggml_free(ctx);
+    }
+}
+
+void llama_numa_init(enum ggml_numa_strategy numa) {
+    if (numa != GGML_NUMA_STRATEGY_DISABLED) {
+        ggml_numa_init(numa);
+    }
+}
+
+void llama_backend_free(void) {
+    ggml_quantize_free();
+}
+
+int64_t llama_time_us(void) {
+    return ggml_time_us();
+}
+
+struct llama_model * llama_load_model_from_file(
+        const char * path_model,
+        struct llama_model_params   params) {
+    ggml_time_init();
+
+    llama_model * model = new llama_model;
+
+    unsigned cur_percentage = 0;
+    if (params.progress_callback == NULL) {
+        params.progress_callback_user_data = &cur_percentage;
+        params.progress_callback = [](float progress, void * ctx) {
+            unsigned * cur_percentage_p = (unsigned *) ctx;
+            unsigned percentage = (unsigned) (100 * progress);
+            while (percentage > *cur_percentage_p) {
+                *cur_percentage_p = percentage;
+                LLAMA_LOG_INFO(".");
+                if (percentage >= 100) {
+                    LLAMA_LOG_INFO("\n");
+                }
+            }
+            return true;
+        };
+    }
+    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
+        // split the servers set them into model->rpc_servers
+        std::string servers(params.rpc_servers);
+        size_t pos = 0;
+        while ((pos = servers.find(",")) != std::string::npos) {
+            std::string server = servers.substr(0, pos);
+            model->rpc_servers.push_back(server);
+            servers.erase(0, pos + 1);
+        }
+        model->rpc_servers.push_back(servers);
+    }
+    int status = llama_model_load(path_model, *model, params);
+    GGML_ASSERT(status <= 0);
+    if (status < 0) {
+        if (status == -1) {
+            LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
+        } else if (status == -2) {
+            LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
+        }
+        delete model;
+        return nullptr;
+    }
+
+    return model;
+}
+
+void llama_free_model(struct llama_model * model) {
+    delete model;
+}
+
+struct llama_context * llama_new_context_with_model(
+                 struct llama_model * model,
+        struct llama_context_params   params) {
+
+    if (!model) {
+        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_batch == 0 && params.n_ubatch == 0) {
+        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
+        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
+        return nullptr;
+    }
+
+    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
+    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
+        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
+        return nullptr;
+    }
+
+    llama_context * ctx = new llama_context(*model);
+
+    const auto & hparams = model->hparams;
+    auto       & cparams = ctx->cparams;
+
+    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
+    cparams.n_threads        = params.n_threads;
+    cparams.n_threads_batch  = params.n_threads_batch;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.defrag_thold     = params.defrag_thold;
+    cparams.embeddings       = params.embeddings;
+    cparams.offload_kqv      = params.offload_kqv;
+    cparams.flash_attn       = params.flash_attn;
+    cparams.pooling_type     = params.pooling_type;
+
+    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
+    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
+    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+
+    // this is necessary due to kv_self.n being padded later during inference
+    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
+
+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
+
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
+    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
+        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
+        cparams.n_batch = GGML_KQ_MASK_PAD;
+    }
+
+    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
+    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
+                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
+                                                              hparams.n_ctx_train;
+
+    cparams.cb_eval           = params.cb_eval;
+    cparams.cb_eval_user_data = params.cb_eval_user_data;
+
+    auto rope_scaling_type = params.rope_scaling_type;
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
+        rope_scaling_type = hparams.rope_scaling_type_train;
+    }
+
+    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
+        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+    }
+
+    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
+        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
+    }
+
+    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
+
+    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
+            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
+        } else {
+            cparams.pooling_type = hparams.pooling_type;
+        }
+    }
+
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
+    if (params.seed == LLAMA_DEFAULT_SEED) {
+        params.seed = time(NULL);
+    }
+
+    LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
+    LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
+    LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
+    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
+    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
+
+    ctx->abort_callback      = params.abort_callback;
+    ctx->abort_callback_data = params.abort_callback_data;
+
+    ctx->sampling.rng = std::mt19937(params.seed);
+    ctx->logits_all   = params.logits_all;
+
+    uint32_t kv_size = cparams.n_ctx;
+    ggml_type type_k = params.type_k;
+    ggml_type type_v = params.type_v;
+
+    // Mamba only needs a constant number of KV cache cells per sequence
+    if (model->arch == LLM_ARCH_MAMBA) {
+        // Mamba needs at least as many KV cells as there are sequences kept at any time
+        kv_size = std::max((uint32_t) 1, params.n_seq_max);
+        // it's probably best to keep as much precision as possible for the states
+        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
+    }
+
+    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
+    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
+
+    if (!hparams.vocab_only) {
+        // initialize backends
+#if defined(GGML_USE_METAL)
+        if (model->n_gpu_layers > 0) {
+            ctx->backend_metal = ggml_backend_metal_init();
+            if (ctx->backend_metal == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(ctx->backend_metal);
+        }
+#elif defined(GGML_USE_CUDA)
+        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+            // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+            ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        } else {
+            // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+            for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
+                ggml_backend_t backend = ggml_backend_cuda_init(device);
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
+#elif defined(GGML_USE_VULKAN)
+        if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+            LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+        if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
+            ggml_backend_t backend = ggml_backend_vk_init(model->main_gpu);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        } else {
+            for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
+                ggml_backend_t backend = ggml_backend_vk_init(device);
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
+#elif defined(GGML_USE_SYCL)
+        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+            ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, model->main_gpu);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        } else {
+            // LLAMA_SPLIT_LAYER requires a backend for each GPU
+            for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
+                ggml_backend_t backend = ggml_backend_sycl_init(i);
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d for No.%d backend\n", __func__, i, i);
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
+#elif defined(GGML_USE_KOMPUTE)
+        if (model->n_gpu_layers > 0) {
+            auto * backend = ggml_backend_kompute_init(model->main_gpu);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+#elif defined(GGML_USE_CANN)
+    // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
+    // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
+    if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
+        ggml_backend_t backend = ggml_backend_cann_init(model->main_gpu);
+        if (backend == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, model->main_gpu);
+            llama_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.push_back(backend);
+    } else {
+        // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
+        // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
+        for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
+            ggml_backend_t backend = ggml_backend_cann_init(device);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+    }
+#endif
+
+#ifdef GGML_USE_BLAS
+        ctx->backend_blas = ggml_backend_blas_init();
+        if (ctx->backend_blas == nullptr) {
+            LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
+        } else {
+            ctx->backends.push_back(ctx->backend_blas);
+        }
+#endif
+
+#if defined(GGML_USE_RPC)
+        if (model->n_gpu_layers > 0) {
+            for (const auto & endpoint : model->rpc_servers) {
+                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
+                if (backend == nullptr) {
+                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
+                    llama_free(ctx);
+                    return nullptr;
+                }
+                ctx->backends.push_back(backend);
+            }
+        }
+#endif
+        ctx->backend_cpu = ggml_backend_cpu_init();
+        if (ctx->backend_cpu == nullptr) {
+            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+        ctx->backends.push_back(ctx->backend_cpu);
+
+        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
+            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }
+
+        {
+            size_t memory_size_k = 0;
+            size_t memory_size_v = 0;
+
+            for (auto & k : ctx->kv_self.k_l) {
+                memory_size_k += ggml_nbytes(k);
+            }
+
+            for (auto & v : ctx->kv_self.v_l) {
+                memory_size_v += ggml_nbytes(v);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+        }
+
+        // graph outputs buffer
+        {
+            // resized during inference when a batch uses more outputs
+            if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
+                LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+
+            LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
+                    ggml_backend_buffer_name(ctx->buf_output),
+                    ggml_backend_buffer_get_size(ctx->buf_output) / 1024.0 / 1024.0);
+        }
+
+        // scheduler and compute buffers
+        {
+            // buffer types used for the compute buffer of each backend
+            std::vector backend_buft;
+            for (auto * backend : ctx->backends) {
+                if (ggml_backend_is_cpu(backend)) {
+                    // use host buffers for the CPU backend compute buffer
+                    backend_buft.push_back(llama_default_buffer_type_cpu(true));
+                } else {
+                    backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
+                }
+            }
+
+            const size_t max_nodes = llama_model_max_nodes(*model);
+
+            // buffer used to store the computation graph and the tensor meta data
+            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+            // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
+            bool pipeline_parallel =
+                llama_get_device_count(*model) > 1 &&
+                model->n_gpu_layers > (int)model->hparams.n_layer &&
+                model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
+                params.offload_kqv;
+#ifndef GGML_USE_CUDA
+            // pipeline parallelism requires support for async compute and events
+            // currently this is only implemented in the CUDA backend
+            pipeline_parallel = false;
+#endif
+            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);
+
+            if (pipeline_parallel) {
+                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
+            }
+
+            // build worst-case graph
+            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
+            int n_past = cparams.n_ctx - n_tokens;
+            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
+
+            // initialize scheduler with the worst-case graph
+            if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+                llama_free(ctx);
+                return nullptr;
+            }
+
+            for (size_t i = 0; i < ctx->backends.size(); i++) {
+                ggml_backend_t backend = ctx->backends[i];
+                ggml_backend_buffer_type_t buft = backend_buft[i];
+                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
+                if (size > 1) {
+                    LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                            ggml_backend_buft_name(buft),
+                            size / 1024.0 / 1024.0);
+                }
+            }
+
+            // note: the number of splits during measure is higher than during inference due to the kv shift
+            int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
+            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, gf->n_nodes);
+            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits);
+        }
+    }
+
+    return ctx;
+}
+
+void llama_free(struct llama_context * ctx) {
+    delete ctx;
+}
+
+const struct llama_model * llama_get_model(const struct llama_context * ctx) {
+    return &ctx->model;
+}
+
+const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) {
+    return &ctx->model.vocab;
+}
+
+uint32_t llama_n_ctx(const struct llama_context * ctx) {
+    return ctx->cparams.n_ctx;
+}
+
+uint32_t llama_n_batch(const struct llama_context * ctx) {
+    return ctx->cparams.n_batch;
+}
+
+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
+uint32_t llama_n_seq_max(const struct llama_context * ctx) {
+    return ctx->kv_self.size;
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
+    return model->vocab.type;
+}
+
+enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+    switch (model->arch) {
+        // these models do not use RoPE
+        case LLM_ARCH_GPT2:
+        case LLM_ARCH_GPTJ:
+        case LLM_ARCH_MPT:
+        case LLM_ARCH_REFACT:
+        case LLM_ARCH_BLOOM:
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_JINA_BERT_V2:
+        case LLM_ARCH_T5:
+        case LLM_ARCH_JAIS:
+            return LLAMA_ROPE_TYPE_NONE;
+
+        // use what we call a normal RoPE, operating on pairs of consecutive head values
+        case LLM_ARCH_LLAMA:
+        case LLM_ARCH_BAICHUAN:
+        case LLM_ARCH_STARCODER:
+        case LLM_ARCH_PLAMO:
+        case LLM_ARCH_ORION:
+        case LLM_ARCH_INTERNLM2:
+        case LLM_ARCH_MINICPM:
+        case LLM_ARCH_XVERSE:
+        case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_OLMO:
+        case LLM_ARCH_ARCTIC:
+        case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_CHATGLM:
+            return LLAMA_ROPE_TYPE_NORM;
+
+        // the pairs of head values are offset by n_rot/2
+        case LLM_ARCH_FALCON:
+        case LLM_ARCH_GROK:
+        case LLM_ARCH_DBRX:
+        case LLM_ARCH_BERT:
+        case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
+        case LLM_ARCH_QWEN:
+        case LLM_ARCH_QWEN2:
+        case LLM_ARCH_QWEN2MOE:
+        case LLM_ARCH_PHI2:
+        case LLM_ARCH_PHI3:
+        case LLM_ARCH_GEMMA:
+        case LLM_ARCH_GEMMA2:
+        case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
+        case LLM_ARCH_GPTNEOX:
+        case LLM_ARCH_CODESHELL:
+            return LLAMA_ROPE_TYPE_NEOX;
+
+        // all model arches should be listed explicitly here
+        case LLM_ARCH_UNKNOWN:
+            GGML_ABORT("unknown architecture");
+    }
+
+    return LLAMA_ROPE_TYPE_NONE;
+}
+
+enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
+    return ctx->cparams.pooling_type;
+}
+
+int32_t llama_n_vocab(const struct llama_model * model) {
+    return model->hparams.n_vocab;
+}
+
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
+}
+
+float llama_rope_freq_scale_train(const struct llama_model * model) {
+    return model->hparams.rope_freq_scale_train;
+}
+
+int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_meta_count(const struct llama_model * model) {
+    return (int)model->gguf_kv.size();
+}
+
+int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->first.c_str());
+}
+
+int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
+    if (i < 0 || i >= (int)model->gguf_kv.size()) {
+        if (buf_size > 0) {
+            buf[0] = '\0';
+        }
+        return -1;
+    }
+    auto it = model->gguf_kv.begin();
+    std::advance(it, i);
+    return snprintf(buf, buf_size, "%s", it->second.c_str());
+}
+
+int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
+    return snprintf(buf, buf_size, "%s %s %s",
+            llama_model_arch_name(model->arch),
+            llama_model_type_name(model->type),
+            llama_model_ftype_name(model->ftype).c_str());
+}
+
+uint64_t llama_model_size(const struct llama_model * model) {
+    uint64_t size = 0;
+    for (const auto & it : model->tensors_by_name) {
+        size += ggml_nbytes(it.second);
+    }
+    return size;
+}
+
+uint64_t llama_model_n_params(const struct llama_model * model) {
+    uint64_t nparams = 0;
+    for (const auto & it : model->tensors_by_name) {
+        nparams += ggml_nelements(it.second);
+    }
+    return nparams;
+}
+
+struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
+    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
+            [name](const std::pair & it) {
+                return it.first == name;
+            });
+    if (it == model->tensors_by_name.end()) {
+        return nullptr;
+    }
+    return it->second;
+}
+
+bool llama_model_has_encoder(const struct llama_model * model) {
+    switch (model->arch) {
+        case LLM_ARCH_T5: return true;
+        default:          return false;
+    }
+}
+
+llama_token llama_model_decoder_start_token(const struct llama_model * model) {
+    return model->hparams.dec_start_token_id;
+}
+
+uint32_t llama_model_quantize(
+        const char * fname_inp,
+        const char * fname_out,
+        const llama_model_quantize_params * params) {
+    try {
+        llama_model_quantize_internal(fname_inp, fname_out, params);
+        return 0;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
+        return 1;
+    }
+}
+
+struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) {
+    try {
+        struct llama_lora_adapter * adapter = new llama_lora_adapter(model);
+        llama_lora_adapter_init_internal(model, path_lora, *adapter);
+        return adapter;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
+        return nullptr;
+    }
+}
+
+static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
+    GGML_ASSERT(cvec.tensors.empty());
+    GGML_ASSERT(cvec.ctxs.empty());
+    GGML_ASSERT(cvec.bufs.empty());
+
+    // count layer buffer types
+    std::map buft_layer_count;
+    for (int64_t i = 0; i < model.hparams.n_layer; i++) {
+        buft_layer_count[model.buft_layer[i].buft]++;
+    }
+
+    // allocate contexts
+    std::map ctx_map;
+    for (auto & it : buft_layer_count) {
+        int n_layers = it.second;
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ n_layers * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_context * ctx = ggml_init(params);
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
+            return 1;
+        }
+        ctx_map[it.first] = ctx;
+    }
+
+    // make tensors
+    cvec.tensors.reserve(model.hparams.n_layer);
+    cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
+    for (size_t il = 1; il < model.hparams.n_layer; il++) {
+        struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
+        ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
+        cvec.tensors.push_back(tensor);
+    }
+
+    // allocate tensors / buffers and zero
+    cvec.ctxs.reserve(ctx_map.size());
+    cvec.bufs.reserve(ctx_map.size());
+    for (auto it : ctx_map) {
+        ggml_backend_buffer_type_t buft = it.first;
+        ggml_context * ctx = it.second;
+        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+        if (!buf) {
+            LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
+            return false;
+        }
+        ggml_backend_buffer_clear(buf, 0);
+        cvec.ctxs.push_back(ctx);
+        cvec.bufs.push_back(buf);
+    }
+
+    return true;
+}
+
+int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
+    const llama_model & model = lctx->model;
+    llama_control_vector & cvec = lctx->cvec;
+
+    if (data == nullptr) {
+        // disable the current control vector (but leave allocated for later)
+        cvec.layer_start = -1;
+        cvec.layer_end   = -1;
+        return 0;
+    }
+
+    if (n_embd != (int) model.hparams.n_embd) {
+        LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
+        return 1;
+    }
+
+    if (cvec.tensors.empty()) {
+        if (!llama_control_vector_init(cvec, model)) {
+            return 1;
+        }
+    }
+
+    cvec.layer_start = il_start;
+    cvec.layer_end   = il_end;
+
+    for (size_t il = 1; il < model.hparams.n_layer; il++) {
+        assert(cvec.tensors[il] != nullptr);
+
+        const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
+        if (off + n_embd <= len) {
+            ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
+        }
+    }
+
+    return 0;
+}
+
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells            = */ 0,
+        /*.n_seq_max          = */ n_seq_max,
+        /*.token_count        = */ 0,
+        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /*.max_contiguous     = */ 0,
+        /*.max_contiguous_idx = */ -1,
+        /*.cells              = */ nullptr,
+        /*.cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_seq_max) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_seq_max; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
+int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
+}
+
+void llama_kv_cache_clear(struct llama_context * ctx) {
+    llama_kv_cache_clear(ctx->kv_self);
+}
+
+bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+}
+
+void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    if (seq_id_src == seq_id_dst) {
+        return;
+    }
+    llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    if (delta == 0) {
+        return;
+    }
+
+    llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
+}
+
+void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
+    return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
+}
+
+void llama_kv_cache_defrag(struct llama_context * ctx) {
+    llama_kv_cache_defrag(ctx->kv_self);
+}
+
+void llama_kv_cache_update(struct llama_context * ctx) {
+    llama_kv_cache_update_internal(*ctx);
+}
+
+// deprecated
+size_t llama_get_state_size(struct llama_context * ctx) {
+    return llama_state_get_size(ctx);
+}
+
+// deprecated
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    return llama_state_get_data(ctx, dst, -1);
+}
+
+// deprecated
+size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+    return llama_state_set_data(ctx, src, -1);
+}
+
+// deprecated
+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+}
+
+// deprecated
+bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
+}
+
+// TODO: replace all non-fatal assertions with returned errors or exceptions
+struct llama_data_write {
+    virtual void write(const void * src, size_t size) = 0;
+    virtual size_t get_size_written() = 0;
+    virtual ~llama_data_write() = default;
+
+    void write_string(const std::string & str) {
+        uint32_t str_size = str.size();
+
+        write(&str_size,  sizeof(str_size));
+        write(str.data(), str_size);
+    }
+
+    void write_model_info(const struct llama_context * ctx) {
+        std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+        write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
+    }
+
+    void write_rng(const std::mt19937 & rng) {
+        std::ostringstream rng_ss;
+        rng_ss << rng;
+
+        const std::string & rng_str = rng_ss.str();
+
+        write_string(rng_str);
+    }
+
+    void write_output_ids(const struct llama_context * ctx) {
+        const uint32_t n_outputs = ctx->n_outputs;
+
+        std::vector output_pos;
+
+        const size_t    n_batch = ctx->cparams.n_batch;
+        const auto & output_ids = ctx->output_ids;
+
+        GGML_ASSERT(n_outputs <= ctx->output_size);
+
+        output_pos.resize(n_outputs);
+
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch; ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT((uint32_t) pos < n_outputs);
+                output_pos[pos] = i;
+            }
+        }
+
+        write(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs) {
+            write(output_pos.data(), n_outputs * sizeof(int32_t));
+        }
+    }
+
+    void write_logits(const struct llama_context * ctx) {
+        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+
+        write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            write(ctx->logits, logits_size * sizeof(float));
+        }
+    }
+
+    void write_embeddings(const struct llama_context * ctx) {
+        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd);
+
+        write(&embeddings_size, sizeof(embeddings_size));
+
+        if (embeddings_size) {
+            write(ctx->embd, embeddings_size * sizeof(float));
+        }
+    }
+
+    void write_kv_cache_meta(const llama_kv_cache & kv_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) {
+
+        for (const auto & range : cell_ranges) {
+            for (uint32_t i = range.first; i < range.second; ++i) {
+                const auto & cell = kv_self.cells[i];
+                const llama_pos pos      = cell.pos;
+                const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+
+                write(&pos,      sizeof(pos));
+                write(&n_seq_id, sizeof(n_seq_id));
+
+                if (n_seq_id) {
+                    for (auto seq_id : cell.seq_id) {
+                        write(&seq_id, sizeof(seq_id));
+                    }
+                }
+            }
+        }
+    }
+
+    void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) {
+        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        const struct llama_hparams & hparams = ctx->model.hparams;
+
+        const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
+        const uint32_t n_layer = hparams.n_layer;
+
+        write(&v_trans, sizeof(v_trans));
+        write(&n_layer, sizeof(n_layer));
+
+        std::vector tmp_buf;
+
+        // Iterate and write all the keys first, each row is a cell
+        // Get whole range at a time
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+            // Write key type
+            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+            write(&k_type_i, sizeof(k_type_i));
+
+            // Write row size of key
+            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            write(&k_size_row, sizeof(k_size_row));
+
+            // Read each range of cells of k_size length each into tmp_buf and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                tmp_buf.resize(range_size * k_size_row);
+                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
+                write(tmp_buf.data(), tmp_buf.size());
+            }
+        }
+
+        if (!kv_self.v_trans) {
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+                // Write value type
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                write(&v_type_i, sizeof(v_type_i));
+
+                // Write row size of value
+                const uint64_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                write(&v_size_row, sizeof(v_size_row));
+
+                // Read each range of cells of v_size length each into tmp_buf and write out
+                for (const auto & range : cell_ranges) {
+                    const size_t range_size = range.second - range.first;
+                    tmp_buf.resize(range_size * v_size_row);
+                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
+                    write(tmp_buf.data(), tmp_buf.size());
+                }
+            }
+        } else {
+            // When v is transposed, we also need the element size and get the element ranges from each row
+            const uint32_t kv_size = kv_self.size;
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+                // Write value type
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                write(&v_type_i, sizeof(v_type_i));
+
+                // Write element size
+                const uint32_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+                write(&v_size_el, sizeof(v_size_el));
+
+                // Write GQA embedding size
+                write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+
+                // For each row, we get the element values of each cell
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    // Read each range of cells of v_size_el length each into tmp_buf and write out
+                    for (const auto & range : cell_ranges) {
+                        const size_t range_size = range.second - range.first;
+                        const size_t src_offset = (range.first + j * kv_size) * v_size_el;
+                        tmp_buf.resize(range_size * v_size_el);
+                        ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
+                        write(tmp_buf.data(), tmp_buf.size());
+                    }
+                }
+            }
+        }
+    }
+
+    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        std::vector> cell_ranges; // ranges, from inclusive, to exclusive
+        uint32_t cell_count = 0;
+
+        // Count the number of cells with the specified seq_id
+        // Find all the ranges of cells with this seq id (or all, when -1)
+        uint32_t cell_range_begin = kv_self.size;
+        for (uint32_t i = 0; i < kv_self.size; ++i) {
+            const auto & cell = kv_self.cells[i];
+            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+                ++cell_count;
+                if (cell_range_begin == kv_self.size) {
+                    cell_range_begin = i;
+                }
+            } else {
+                if (cell_range_begin != kv_self.size) {
+                    cell_ranges.emplace_back(cell_range_begin, i);
+                    cell_range_begin = kv_self.size;
+                }
+            }
+        }
+        if (cell_range_begin != kv_self.size) {
+            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
+        }
+
+        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+        uint32_t cell_count_check = 0;
+        for (const auto & range : cell_ranges) {
+            cell_count_check += range.second - range.first;
+        }
+        GGML_ASSERT(cell_count == cell_count_check);
+
+        write(&cell_count, sizeof(cell_count));
+
+        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
+        write_kv_cache_data(ctx, cell_ranges);
+    }
+};
+
+struct llama_data_read {
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+    virtual size_t get_size_read() = 0;
+    virtual ~llama_data_read() = default;
+
+    void read_string(std::string & str) {
+        uint32_t str_size;
+        read_to(&str_size, sizeof(str_size));
+
+        str.assign((const char *) read(str_size), str_size);
+    }
+
+    // validate model information
+    void read_model_info(const struct llama_context * ctx) {
+        std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
+        std::string arch_str;
+        read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }
+
+    void read_rng(std::mt19937 & rng) {
+        std::string rng_str;
+        read_string(rng_str);
+
+        std::istringstream rng_ss(rng_str);
+        rng_ss >> rng;
+
+        if (rng_ss.fail()) {
+            throw std::runtime_error("failed to load RNG state");
+        }
+    }
+
+    void read_output_ids(struct llama_context * ctx) {
+        std::vector output_pos;
+
+        uint32_t n_outputs;
+        read_to(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                if ((uint32_t) id >= ctx->cparams.n_batch) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
+                }
+                ctx->output_ids[id] = i;
+            }
+
+            ctx->n_outputs = n_outputs;
+        }
+    }
+
+    void read_logits(struct llama_context * ctx) {
+        uint64_t logits_size;
+        read_to(&logits_size, sizeof(logits_size));
+
+        if (ctx->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
+
+        if (logits_size) {
+            read_to(ctx->logits, logits_size * sizeof(float));
+        }
+    }
+
+    void read_embeddings(struct llama_context * ctx) {
+        uint64_t embeddings_size;
+        read_to(&embeddings_size, sizeof(embeddings_size));
+
+        if (ctx->embd_size < embeddings_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
+
+        if (embeddings_size) {
+            read_to(ctx->embd, embeddings_size * sizeof(float));
+        }
+    }
+
+    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
+        struct llama_kv_cache & kv_self = ctx->kv_self;
+
+        if (dest_seq_id != -1) {
+            // single sequence
+
+            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
+
+            llama_batch batch = llama_batch_init(cell_count, 0, 1);
+            batch.n_tokens = cell_count;
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                llama_pos pos;
+                uint32_t n_seq_id;
+
+                read_to(&pos, sizeof(pos));
+                read_to(&n_seq_id, sizeof(n_seq_id));
+
+                if (n_seq_id != 0) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                    return false;
+                }
+
+                batch.pos[i] = pos;
+                batch.n_seq_id[i] = 1;
+                batch.seq_id[i][0] = dest_seq_id;
+            }
+            if (!llama_kv_cache_find_slot(kv_self, batch)) {
+                llama_batch_free(batch);
+                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+                return false;
+            }
+
+            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+            // Assume that this is one contiguous block of cells
+            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
+            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
+
+            // Cleanup
+            llama_batch_free(batch);
+        } else {
+            // whole KV cache restore
+
+            if (cell_count > kv_self.size) {
+                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+                return false;
+            }
+
+            llama_kv_cache_clear(kv_self);
+
+            for (uint32_t i = 0; i < cell_count; ++i) {
+                llama_kv_cell & cell = kv_self.cells[i];
+
+                llama_pos pos;
+                uint32_t  n_seq_id;
+
+                read_to(&pos,      sizeof(pos));
+                read_to(&n_seq_id, sizeof(n_seq_id));
+
+                cell.pos = pos;
+
+                for (uint32_t j = 0; j < n_seq_id; ++j) {
+                    llama_seq_id seq_id;
+                    read_to(&seq_id, sizeof(seq_id));
+
+                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                        return false;
+                    }
+
+                    cell.seq_id.insert(seq_id);
+                }
+            }
+
+            kv_self.head = 0;
+            kv_self.used = cell_count;
+        }
+
+        return true;
+    }
+
+    bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+        const struct llama_hparams & hparams = ctx->model.hparams;
+        struct llama_kv_cache & kv_self = ctx->kv_self;
+        uint32_t v_trans;
+        uint32_t n_layer;
+        read_to(&v_trans, sizeof(v_trans));
+        read_to(&n_layer, sizeof(n_layer));
+
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > kv_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
+            return false;
+        }
+        if (kv_self.v_trans != (bool) v_trans) {
+            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+            return false;
+        }
+
+        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
+            // Read type of key
+            int32_t k_type_i_ref;
+            read_to(&k_type_i_ref, sizeof(k_type_i_ref));
+            const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
+            if (k_type_i != k_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of key
+            uint64_t k_size_row_ref;
+            read_to(&k_size_row_ref, sizeof(k_size_row_ref));
+            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            if (k_size_row != k_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
+                return false;
+            }
+
+            if (cell_count) {
+                // Read and set the keys for the whole cell range
+                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
+            }
+        }
+
+        if (!kv_self.v_trans) {
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+                // Read type of value
+                int32_t v_type_i_ref;
+                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                if (v_type_i != v_type_i_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                    return false;
+                }
+
+                // Read row size of value
+                uint64_t v_size_row_ref;
+                read_to(&v_size_row_ref, sizeof(v_size_row_ref));
+                const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
+                if (v_size_row != v_size_row_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
+                    return false;
+                }
+
+                if (cell_count) {
+                    // Read and set the values for the whole cell range
+                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
+                }
+            }
+        } else {
+            // For each layer, read the values for each cell (transposed)
+            for (uint32_t il = 0; il < n_layer; ++il) {
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
+                // Read type of value
+                int32_t v_type_i_ref;
+                read_to(&v_type_i_ref, sizeof(v_type_i_ref));
+                const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
+                if (v_type_i != v_type_i_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
+                    return false;
+                }
+
+                // Read element size of value
+                uint32_t v_size_el_ref;
+                read_to(&v_size_el_ref, sizeof(v_size_el_ref));
+                const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
+                if (v_size_el != v_size_el_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
+                    return false;
+                }
+
+                // Read GQA embedding size
+                uint32_t n_embd_v_gqa_ref;
+                read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+                if (n_embd_v_gqa != n_embd_v_gqa_ref) {
+                    LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
+                    return false;
+                }
+
+                if (cell_count) {
+                    // For each row in the transposed matrix, read the values for the whole cell range
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
+                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
+                }
+            }
+        }
+        return true;
+    }
+
+    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        uint32_t cell_count;
+        read_to(&cell_count, sizeof(cell_count));
+
+        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
+
+        if (!res) {
+            if (seq_id == -1) {
+                llama_kv_cache_clear(ctx);
+            } else {
+                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+            }
+            throw std::runtime_error("failed to restore kv cache");
+        }
+    }
+};
+
+struct llama_data_write_dummy : llama_data_write {
+    size_t size_written = 0;
+
+    llama_data_write_dummy() {}
+
+    // TODO: avoid unnecessary calls to ggml_backend_tensor_get in a dummy context
+
+    void write(const void * /* src */, size_t size) override {
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_write_buffer : llama_data_write {
+    uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_written = 0;
+
+    llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    void write(const void * src, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_read_buffer : llama_data_read {
+    const uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_read = 0;
+
+    llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    const uint8_t * read(size_t size) override {
+        const uint8_t * base_ptr = ptr;
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ptr += size;
+        size_read += size;
+        buf_size -= size;
+        return base_ptr;
+    }
+
+    void read_to(void * dst, size_t size) override {
+        memcpy(dst, read(size), size);
+    }
+
+    size_t get_size_read() override {
+        return size_read;
+    }
+};
+
+struct llama_data_write_file : llama_data_write {
+    llama_file * file;
+    size_t size_written = 0;
+
+    llama_data_write_file(llama_file * f) : file(f) {}
+
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
+    }
+
+    size_t get_size_written() override {
+        return size_written;
+    }
+};
+
+struct llama_data_read_file : llama_data_read {
+    llama_file * file;
+    size_t size_read = 0;
+    std::vector temp_buffer;
+
+    llama_data_read_file(llama_file * f) : file(f) {}
+
+    void read_to(void * dst, size_t size) override {
+        file->read_raw(dst, size);
+        size_read += size;
+    }
+
+    const uint8_t * read(size_t size) override {
+        temp_buffer.resize(size);
+        read_to(temp_buffer.data(), size);
+        return temp_buffer.data();
+    }
+
+    size_t get_size_read() override {
+        return size_read;
+    }
+};
+
+/** copy state data into either a buffer or file depending on the passed in context
+ *
+ * file context:
+ * llama_file file("/path", "wb");
+ * llama_data_write_file data_ctx(&file);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+ * buffer context:
+ * std::vector buf(max_size, 0);
+ * llama_data_write_buffer data_ctx(buf.data(), max_size);
+ * llama_state_get_data_internal(ctx, data_ctx);
+ *
+*/
+static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) {
+    llama_synchronize(ctx);
+
+    data_ctx.write_model_info(ctx);
+
+    data_ctx.write_rng(ctx->sampling.rng);
+
+    // copy outputs
+    data_ctx.write_output_ids(ctx);
+    data_ctx.write_logits(ctx);
+    data_ctx.write_embeddings(ctx);
+
+    data_ctx.write_kv_cache(ctx);
+
+    return data_ctx.get_size_written();
+}
+
+size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
+    llama_data_write_buffer data_ctx(dst, size);
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+// Returns the *actual* size of the state.
+// Intended to be used when saving to state to a buffer.
+size_t llama_state_get_size(struct llama_context * ctx) {
+    llama_data_write_dummy data_ctx;
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
+    llama_synchronize(ctx);
+
+    data_ctx.read_model_info(ctx);
+
+    // set rng
+    data_ctx.read_rng(ctx->sampling.rng);
+
+    // set outputs
+    data_ctx.read_output_ids(ctx);
+    data_ctx.read_logits(ctx);
+    data_ctx.read_embeddings(ctx);
+
+    data_ctx.read_kv_cache(ctx);
+
+    return data_ctx.get_size_read();
+}
+
+// Sets the state reading from the specified source address
+size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
+    llama_data_read_buffer data_ctx(src, size);
+    try {
+        return llama_state_set_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(path_session, "rb");
+
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size - file.tell();
+
+        llama_data_read_file data_ctx(&file);
+        const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
+
+        if (n_read != n_state_size_cur) {
+            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
+            return false;
+        }
+    }
+    return true;
+}
+
+bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
+static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(path_session, "wb");
+
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_data_write_file data_ctx(&file);
+    llama_state_get_data_internal(ctx, data_ctx);
+
+    return true;
+}
+
+bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    try {
+        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
+
+    data_ctx.write_kv_cache(ctx, seq_id);
+
+    return data_ctx.get_size_written();
+}
+
+size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_data_write_dummy data_ctx;
+    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+}
+
+size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    llama_data_write_buffer data_ctx(dst, size);
+    try {
+        return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
+    llama_synchronize(ctx);
+
+    data_ctx.read_kv_cache(ctx, dest_seq_id);
+
+    return data_ctx.get_size_read();
+}
+
+size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) {
+    llama_data_read_buffer data_ctx(src, size);
+    try {
+        return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
+
+    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
+    file.write_u32(LLAMA_STATE_SEQ_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_data_write_file data_ctx(&file);
+    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+
+    const size_t res = file.tell();
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
+    return res;
+}
+
+static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
+
+    // version checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return 0;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t state_size = file.size - file.tell();
+        llama_data_read_file data_ctx(&file);
+        const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
+        if (!nread) {
+            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
+            return 0;
+        }
+        GGML_ASSERT(nread <= state_size);
+        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+    }
+
+    return file.tell();
+}
+
+size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    try {
+        return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
+    ctx->cparams.n_threads       = n_threads;
+    ctx->cparams.n_threads_batch = n_threads_batch;
+}
+
+uint32_t llama_n_threads(struct llama_context * ctx) {
+    return ctx->cparams.n_threads;
+}
+
+uint32_t llama_n_threads_batch(struct llama_context * ctx) {
+    return ctx->cparams.n_threads_batch;
+}
+
+void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
+void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+    ctx->cparams.causal_attn = causal_attn;
+}
+
+struct llama_batch llama_batch_get_one(
+             llama_token * tokens,
+                 int32_t   n_tokens,
+               llama_pos   pos_0,
+            llama_seq_id   seq_id) {
+    return {
+        /*n_tokens       =*/ n_tokens,
+        /*tokens         =*/ tokens,
+        /*embd           =*/ nullptr,
+        /*pos            =*/ nullptr,
+        /*n_seq_id       =*/ nullptr,
+        /*seq_id         =*/ nullptr,
+        /*logits         =*/ nullptr,
+        /*all_pos_0      =*/ pos_0,
+        /*all_pos_1      =*/ 1,
+        /*all_seq_id     =*/ seq_id,
+    };
+}
+
+struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
+    llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
+
+    if (embd) {
+        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+    } else {
+        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
+    }
+
+    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
+    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
+    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+    for (int i = 0; i < n_tokens_alloc; ++i) {
+        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+    }
+    batch.seq_id[n_tokens_alloc] = nullptr;
+
+    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
+
+    return batch;
+}
+
+void llama_batch_free(struct llama_batch batch) {
+    if (batch.token)    free(batch.token);
+    if (batch.embd)     free(batch.embd);
+    if (batch.pos)      free(batch.pos);
+    if (batch.n_seq_id) free(batch.n_seq_id);
+    if (batch.seq_id) {
+        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
+            free(batch.seq_id[i]);
+        }
+        free(batch.seq_id);
+    }
+    if (batch.logits)   free(batch.logits);
+}
+
+int32_t llama_encode(
+        struct llama_context * ctx,
+          struct llama_batch   batch) {
+    const int ret = llama_encode_internal(*ctx, batch);
+    if (ret < 0) {
+        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+int32_t llama_decode(
+        struct llama_context * ctx,
+          struct llama_batch   batch) {
+    const int ret = llama_decode_internal(*ctx, batch);
+    if (ret < 0) {
+        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+void llama_synchronize(struct llama_context * ctx) {
+    ggml_backend_sched_synchronize(ctx->sched);
+
+    // FIXME: if multiple single tokens are evaluated without a synchronization,
+    // the stats will be added to the prompt evaluation stats
+    // this should only happen when using batch size 1 to evaluate a batch
+
+    // add the evaluation to the stats
+    if (ctx->n_queued_tokens == 1) {
+        ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        ctx->n_eval++;
+    } else if (ctx->n_queued_tokens > 1) {
+        ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
+        ctx->n_p_eval += ctx->n_queued_tokens;
+    }
+
+    // get a more accurate load time, upon first eval
+    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
+        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
+        ctx->has_evaluated_once = true;
+    }
+
+    ctx->n_queued_tokens = 0;
+    ctx->t_compute_start_us = 0;
+}
+
+float * llama_get_logits(struct llama_context * ctx) {
+    llama_synchronize(ctx);
+
+    return ctx->logits;
+}
+
+float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+    llama_synchronize(ctx);
+
+    try {
+        if (ctx->logits == nullptr) {
+            throw std::runtime_error("no logits");
+        }
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
+        }
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if (j >= ctx->n_outputs) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+        }
+
+        return ctx->logits + j*ctx->model.hparams.n_vocab;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#endif
+        return nullptr;
+    }
+}
+
+float * llama_get_embeddings(struct llama_context * ctx) {
+    llama_synchronize(ctx);
+
+    return ctx->embd;
+}
+
+float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+    int32_t j = -1;
+
+    llama_synchronize(ctx);
+
+    try {
+        if (ctx->embd == nullptr) {
+            throw std::runtime_error("no embeddings");
+        }
+
+        if (i < 0) {
+            j = ctx->n_outputs + i;
+            if (j < 0) {
+                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+            }
+        } else if ((size_t) i >= ctx->output_ids.size()) {
+            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+        } else {
+            j = ctx->output_ids[i];
+        }
+
+        if (j < 0) {
+            throw std::runtime_error(format("batch.logits[%d] != true", i));
+        }
+        if (j >= ctx->n_outputs) {
+            // This should not happen
+            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+        }
+
+        return ctx->embd + j*ctx->model.hparams.n_embd;
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+        GGML_ABORT("fatal error");
+#endif
+        return nullptr;
+    }
+}
+
+float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
+    llama_synchronize(ctx);
+
+    auto it = ctx->embd_seq.find(seq_id);
+    if (it == ctx->embd_seq.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
+//
+// vocab
+//
+
+const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
+    return llama_token_get_text_impl(model->vocab, token);
+}
+
+float llama_token_get_score(const struct llama_model * model, llama_token token) {
+    return llama_token_get_score_impl(model->vocab, token);
+}
+
+enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
+    return llama_token_get_attr_impl(model->vocab, token);
+}
+
+bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
+    return llama_token_is_eog_impl(model->vocab, token);
+}
+
+bool llama_token_is_control(const struct llama_model * model, llama_token token) {
+    return llama_token_is_control_impl(model->vocab, token);
+}
+
+llama_token llama_token_bos(const struct llama_model * model) {
+    return llama_token_bos_impl(model->vocab);
+}
+
+llama_token llama_token_eos(const struct llama_model * model) {
+    return llama_token_eos_impl(model->vocab);
+}
+
+llama_token llama_token_cls(const struct llama_model * model) {
+    return llama_token_cls_impl(model->vocab);
+}
+
+llama_token llama_token_sep(const struct llama_model * model) {
+    return llama_token_sep_impl(model->vocab);
+}
+
+llama_token llama_token_nl (const struct llama_model * model) {
+    return llama_token_nl_impl(model->vocab);
+}
+
+llama_token llama_token_pad(const struct llama_model * model) {
+    return llama_token_pad_impl(model->vocab);
+}
+
+int32_t llama_add_bos_token(const struct llama_model * model) {
+    return llama_add_bos_token_impl(model->vocab);
+}
+
+int32_t llama_add_eos_token(const struct llama_model * model) {
+    return llama_add_eos_token_impl(model->vocab);
+}
+
+llama_token llama_token_prefix(const struct llama_model * model) {
+    return llama_token_prefix_impl(model->vocab);
+}
+
+llama_token llama_token_middle(const struct llama_model * model) {
+    return llama_token_middle_impl(model->vocab);
+}
+
+llama_token llama_token_suffix(const struct llama_model * model) {
+    return llama_token_suffix_impl(model->vocab);
+}
+
+llama_token llama_token_eot(const struct llama_model * model) {
+    return llama_token_eot_impl(model->vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_model * model,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_model * model,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_model * model,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
+//
+// chat templates
+//
+
+// Simple version of "llama_apply_chat_template" that only works with strings
+// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
+static int32_t llama_chat_apply_template_internal(
+    const std::string & tmpl,
+    const std::vector & chat,
+    std::string & dest, bool add_ass) {
+    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
+    std::stringstream ss;
+    auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
+        return tmpl.find(haystack) != std::string::npos;
+    };
+    if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
+        // chatml template
+        for (auto message : chat) {
+            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|im_start|>assistant\n";
+        }
+    } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
+        // llama2 template and its variants
+        // [variant] support system message
+        bool support_system_message = tmpl_contains("<>") || tmpl == "mistral";
+        // [variant] space before + after response
+        bool space_around_response = tmpl_contains("' ' + eos_token");
+        // [variant] add BOS inside history
+        bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
+        // [variant] trim spaces from the input message
+        bool strip_message = tmpl_contains("content.strip()");
+        // construct the prompt
+        bool is_inside_turn = true; // skip BOS at the beginning
+        ss << "[INST] ";
+        for (auto message : chat) {
+            std::string content = strip_message ? trim(message->content) : message->content;
+            std::string role(message->role);
+            if (!is_inside_turn) {
+                is_inside_turn = true;
+                ss << (add_bos_inside_history ? "[INST] " : "[INST] ");
+            }
+            if (role == "system") {
+                if (support_system_message) {
+                    ss << "<>\n" << content << "\n<>\n\n";
+                } else {
+                    // if the model does not support system message, we still include it in the first message, but without <>
+                    ss << content << "\n";
+                }
+            } else if (role == "user") {
+                ss << content << " [/INST]";
+            } else {
+                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "";
+                is_inside_turn = false;
+            }
+        }
+        // llama2 templates seem to not care about "add_generation_prompt"
+    } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
+        // Phi 3
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
+        // zephyr template
+        for (auto message : chat) {
+            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
+        }
+        if (add_ass) {
+            ss << "<|assistant|>\n";
+        }
+    } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
+        // mlabonne/AlphaMonarch-7B template (the  is included inside history)
+        for (auto message : chat) {
+            std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message
+            ss << bos << message->role << "\n" << message->content << "\n";
+        }
+        if (add_ass) {
+            ss << "assistant\n";
+        }
+    } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) {
+        // google/gemma-7b-it
+        std::string system_prompt = "";
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
+                system_prompt = trim(message->content);
+                continue;
+            }
+            // in gemma, "assistant" is "model"
+            role = role == "assistant" ? "model" : message->role;
+            ss << "" << role << "\n";
+            if (!system_prompt.empty() && role != "model") {
+                ss << system_prompt << "\n\n";
+                system_prompt = "";
+            }
+            ss << trim(message->content) << "\n";
+        }
+        if (add_ass) {
+            ss << "model\n";
+        }
+    } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
+        // OrionStarAI/Orion-14B-Chat
+        std::string system_prompt = "";
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                // there is no system message support, we will merge it with user prompt
+                system_prompt = message->content;
+                continue;
+            } else if (role == "user") {
+                ss << "Human: ";
+                if (!system_prompt.empty()) {
+                    ss << system_prompt << "\n\n";
+                    system_prompt = "";
+                }
+                ss << message->content << "\n\nAssistant: ";
+            } else {
+                ss << message->content << "";
+            }
+        }
+    } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) {
+        // openchat/openchat-3.5-0106,
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "<|end_of_turn|>";
+            } else {
+                role[0] = toupper(role[0]);
+                ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
+            }
+        }
+        if (add_ass) {
+            ss << "GPT4 Correct Assistant:";
+        }
+    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) {
+        // eachadea/vicuna-13b-1.1 (and Orca variant)
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                // Orca-Vicuna variant uses a system prefix
+                if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) {
+                    ss << "SYSTEM: " << message->content << "\n";
+                } else {
+                    ss << message->content << "\n\n";
+                }
+            } else if (role == "user") {
+                ss << "USER: " << message->content << "\n";
+            } else if (role == "assistant") {
+                ss << "ASSISTANT: " << message->content << "\n";
+            }
+        }
+        if (add_ass) {
+            ss << "ASSISTANT:";
+        }
+    } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) {
+        // deepseek-ai/deepseek-coder-33b-instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content;
+            } else if (role == "user") {
+                ss << "### Instruction:\n" << message->content << "\n";
+            } else if (role == "assistant") {
+                ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
+            }
+        }
+        if (add_ass) {
+            ss << "### Response:\n";
+        }
+    } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) {
+        // CohereForAI/c4ai-command-r-plus
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
+            } else if (role == "user") {
+                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
+            } else if (role == "assistant") {
+                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
+        }
+    } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) {
+        // Llama 3
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
+        }
+        if (add_ass) {
+            ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
+        }
+    } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
+        // chatglm3-6b
+        ss << "[gMASK]" << "sop";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n " << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
+    } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) {
+        ss << "[gMASK]" << "";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
+    } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << LU8("<用户>");
+                ss << trim(message->content);
+                ss << "";
+            } else {
+                ss << trim(message->content);
+            }
+        }
+    } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+        // DeepSeek-V2
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << LU8("<|end▁of▁sentence|>");
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
+    } else {
+        // template not supported
+        return -1;
+    }
+    dest = ss.str();
+    return dest.size();
+}
+
+int32_t llama_chat_apply_template(
+                const struct llama_model * model,
+                              const char * tmpl,
+         const struct llama_chat_message * chat,
+                                  size_t   n_msg,
+                                    bool   add_ass,
+                                    char * buf,
+                                 int32_t   length) {
+    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
+    if (tmpl == nullptr) {
+        GGML_ASSERT(model != nullptr);
+        // load template from model
+        std::vector model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+        if (res < 0) {
+            // worst case: there is no information about template, we will use chatml by default
+            curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
+        } else {
+            curr_tmpl = std::string(model_template.data(), model_template.size());
+        }
+    }
+
+    // format the chat to string
+    std::vector chat_vec;
+    chat_vec.resize(n_msg);
+    for (size_t i = 0; i < n_msg; i++) {
+        chat_vec[i] = &chat[i];
+    }
+
+    std::string formatted_chat;
+    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
+    if (res < 0) {
+        return res;
+    }
+    if (buf && length > 0) {
+        strncpy(buf, formatted_chat.c_str(), length);
+    }
+    return res;
+}
+
+//
+// grammar
+//
+
+struct llama_grammar * llama_grammar_init(
+        const llama_grammar_element ** rules,
+        size_t    n_rules,
+        size_t    start_rule_index) {
+    return llama_grammar_init_impl(rules, n_rules, start_rule_index);
+}
+
+void llama_grammar_free(struct llama_grammar * grammar) {
+    llama_grammar_free_impl(grammar);
+}
+
+struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
+    return llama_grammar_copy_impl(grammar);
+}
+
+void llama_grammar_sample(
+      const struct llama_grammar * grammar,
+      const struct llama_context * ctx,
+          llama_token_data_array * candidates) {
+    llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates);
+}
+
+void llama_sample_grammar(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+      const struct llama_grammar * grammar) {
+    llama_grammar_sample(grammar, ctx, candidates);
+}
+
+void llama_grammar_accept_token(
+            struct llama_grammar * grammar,
+            struct llama_context * ctx,
+                     llama_token   token) {
+    llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token);
+}
+
+//
+// sampling
+//
+
+void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) {
+    llama_set_rng_seed_impl(&ctx->sampling, seed);
+}
+
+void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) {
+    llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates);
+}
+
+void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
+    llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep);
+}
+
+void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+}
+
+void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+}
+
+void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {
+    llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep);
+}
+
+void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) {
+    llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep);
+}
+
+void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) {
+    llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val);
+}
+
+void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) {
+    llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
+}
+
+void llama_sample_repetition_penalties(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+               const llama_token * last_tokens,
+                          size_t   penalty_last_n,
+                           float   penalty_repeat,
+                           float   penalty_freq,
+                           float   penalty_present) {
+    llama_sample_repetition_penalties_impl(ctx ? &ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+}
+
+void llama_sample_apply_guidance(
+          struct llama_context * ctx,
+                         float * logits,
+                         float * logits_guidance,
+                         float   scale) {
+    llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale);
+}
+
+llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
+    return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu);
+}
+
+llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) {
+    return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu);
+}
+
+llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates);
+}
+
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
+    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng);
+}
+
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng);
+}
+
+int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
+    static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
+    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
+        return strlen(split_path);
+    }
+    return 0;
+}
+
+int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
+    std::string str_split_path(split_path);
+    char postfix[32];
+    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
+    std::string str_postfix(postfix);
+
+    // check if dest ends with postfix
+    int size_prefix = str_split_path.size() - str_postfix.size();
+    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
+        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        return size_prefix;
+    }
+
+    return 0;
+}
+
+struct llama_timings llama_get_timings(struct llama_context * ctx) {
+    struct llama_timings result = {
+        /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
+        /*.t_end_ms    =*/ 1.00 * ggml_time_ms(),
+        /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
+        /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us,
+        /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
+        /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
+
+        /*.n_sample =*/ std::max(1, ctx->sampling.n_sample),
+        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
+        /*.n_eval   =*/ std::max(1, ctx->n_eval),
+    };
+
+    return result;
+}
+
+void llama_print_timings(struct llama_context * ctx) {
+    const llama_timings timings = llama_get_timings(ctx);
+
+    LLAMA_LOG_INFO("\n");
+    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
+    LLAMA_LOG_INFO("%s:      sample time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
+    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
+    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+            __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
+    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
+}
+
+void llama_reset_timings(struct llama_context * ctx) {
+    ctx->t_start_us  = ggml_time_us();
+    ctx->t_eval_us   = ctx->n_eval   = 0;
+    ctx->t_p_eval_us = ctx->n_p_eval = 0;
+
+    ctx->sampling.reset_timings();
+}
+
+const char * llama_print_system_info(void) {
+    static std::string s;
+
+    s  = "";
+    s += "AVX = "         + std::to_string(ggml_cpu_has_avx())         + " | ";
+    s += "AVX_VNNI = "    + std::to_string(ggml_cpu_has_avx_vnni())    + " | ";
+    s += "AVX2 = "        + std::to_string(ggml_cpu_has_avx2())        + " | ";
+    s += "AVX512 = "      + std::to_string(ggml_cpu_has_avx512())      + " | ";
+    s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
+    s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
+    s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
+    s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
+    s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
+    s += "SVE = "         + std::to_string(ggml_cpu_has_sve())         + " | ";
+    s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";
+    s += "F16C = "        + std::to_string(ggml_cpu_has_f16c())        + " | ";
+    s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";
+    s += "WASM_SIMD = "   + std::to_string(ggml_cpu_has_wasm_simd())   + " | ";
+    s += "BLAS = "        + std::to_string(ggml_cpu_has_blas())        + " | ";
+    s += "SSE3 = "        + std::to_string(ggml_cpu_has_sse3())        + " | ";
+    s += "SSSE3 = "       + std::to_string(ggml_cpu_has_ssse3())       + " | ";
+    s += "VSX = "         + std::to_string(ggml_cpu_has_vsx())         + " | ";
+    s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
+    s += "LLAMAFILE = "   + std::to_string(ggml_cpu_has_llamafile())   + " | ";
+
+    return s.c_str();
+}
+
+void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
+    fprintf(stream, "\n");
+    fprintf(stream, "###########\n");
+    fprintf(stream, "# Timings #\n");
+    fprintf(stream, "###########\n");
+    fprintf(stream, "\n");
+
+    fprintf(stream, "mst_eval: %.2f  # ms / token during generation\n",
+            1.0e-3 * ctx->t_eval_us / ctx->n_eval);
+    fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
+            1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
+    fprintf(stream, "mst_sample: %.2f  # ms / token during sampling\n",
+            1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample);
+    fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
+    fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
+    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->sampling.n_sample);
+    fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
+    fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
+    fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
+    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->sampling.t_sample_us);
+    fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
+            1.0e6 * ctx->n_eval / ctx->t_eval_us);
+    fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
+            1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
+    fprintf(stream, "ts_sample: %.2f  # tokens / second during sampling\n",
+            1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us);
+}
+
+// For internal test use
+const std::vector> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+) {
+    return ctx->model.tensors_by_name;
+}
+
+void llama_log_set(ggml_log_callback log_callback, void * user_data) {
+    g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
+    g_state.log_callback_user_data = user_data;
+#ifdef GGML_USE_METAL
+    ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#elif defined(GGML_USE_CUDA)
+    ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#elif defined(GGML_USE_CANN)
+    ggml_backend_cann_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
+#endif
+}
+
+static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
+    va_list args_copy;
+    va_copy(args_copy, args);
+    char buffer[128];
+    int len = vsnprintf(buffer, 128, format, args);
+    if (len < 128) {
+        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
+    } else {
+        char* buffer2 = new char[len+1];
+        vsnprintf(buffer2, len+1, format, args_copy);
+        buffer2[len] = 0;
+        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
+        delete[] buffer2;
+    }
+    va_end(args_copy);
+}
+
+void llama_log_internal(ggml_log_level level, const char * format, ...) {
+    va_list args;
+    va_start(args, format);
+    llama_log_internal_v(level, format, args);
+    va_end(args);
+}
+
+void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) user_data;
+    fputs(text, stderr);
+    fflush(stderr);
+}
+
 static int llama_apply_lora_from_file_internal(
     const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
 ) {
@@ -15844,741 +19463,6 @@ static int llama_apply_lora_from_file_internal(
     return 0;
 }
 
-//
-// interface implementation
-//
-struct llama_model_params llama_model_default_params() {
-    struct llama_model_params result = {
-        /*.n_gpu_layers                =*/ 0,
-        /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
-        /*.main_gpu                    =*/ 0,
-        /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
-        /*.progress_callback           =*/ nullptr,
-        /*.progress_callback_user_data =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-        /*.vocab_only                  =*/ false,
-        /*.use_mmap                    =*/ true,
-        /*.use_mlock                   =*/ false,
-        /*.check_tensors               =*/ false,
-    };
-
-#ifdef GGML_USE_METAL
-    // note: we usually have plenty of VRAM, so by default offload all layers to the GPU
-    result.n_gpu_layers = 999;
-#endif
-
-    return result;
-}
-
-struct llama_context_params llama_context_default_params() {
-    struct llama_context_params result = {
-        /*.seed                        =*/ LLAMA_DEFAULT_SEED,
-        /*.n_ctx                       =*/ 512,
-        /*.n_batch                     =*/ 2048,
-        /*.n_ubatch                    =*/ 512,
-        /*.n_seq_max                   =*/ 1,
-        /*.n_threads                   =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
-        /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
-        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
-        /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.rope_freq_base              =*/ 0.0f,
-        /*.rope_freq_scale             =*/ 0.0f,
-        /*.yarn_ext_factor             =*/ -1.0f,
-        /*.yarn_attn_factor            =*/ 1.0f,
-        /*.yarn_beta_fast              =*/ 32.0f,
-        /*.yarn_beta_slow              =*/ 1.0f,
-        /*.yarn_orig_ctx               =*/ 0,
-        /*.defrag_thold                =*/ -1.0f,
-        /*.cb_eval                     =*/ nullptr,
-        /*.cb_eval_user_data           =*/ nullptr,
-        /*.type_k                      =*/ GGML_TYPE_F16,
-        /*.type_v                      =*/ GGML_TYPE_F16,
-        /*.logits_all                  =*/ false,
-        /*.embeddings                  =*/ false,
-        /*.offload_kqv                 =*/ true,
-        /*.flash_attn                  =*/ false,
-        /*.abort_callback              =*/ nullptr,
-        /*.abort_callback_data         =*/ nullptr,
-    };
-
-    return result;
-}
-
-struct llama_model_quantize_params llama_model_quantize_default_params() {
-    struct llama_model_quantize_params result = {
-        /*.nthread                     =*/ 0,
-        /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1,
-        /*.output_tensor_type          =*/ GGML_TYPE_COUNT,
-        /*.token_embedding_type        =*/ GGML_TYPE_COUNT,
-        /*.allow_requantize            =*/ false,
-        /*.quantize_output_tensor      =*/ true,
-        /*.only_copy                   =*/ false,
-        /*.pure                        =*/ false,
-        /*.keep_split                  =*/ false,
-        /*.imatrix                     =*/ nullptr,
-        /*.kv_overrides                =*/ nullptr,
-    };
-
-    return result;
-}
-
-size_t llama_max_devices(void) {
-#if defined(GGML_USE_RPC)
-    return GGML_RPC_MAX_SERVERS;
-#elif defined(GGML_USE_METAL)
-    return 1;
-#elif defined(GGML_USE_CUDA)
-    return GGML_CUDA_MAX_DEVICES;
-#elif defined(GGML_USE_SYCL)
-    return GGML_SYCL_MAX_DEVICES;
-#elif defined(GGML_USE_VULKAN)
-    return GGML_VK_MAX_DEVICES;
-#else
-    return 1;
-#endif
-}
-
-bool llama_supports_mmap(void) {
-    return llama_mmap::SUPPORTED;
-}
-
-bool llama_supports_mlock(void) {
-    return llama_mlock::SUPPORTED;
-}
-
-bool llama_supports_gpu_offload(void) {
-#if defined(GGML_USE_CUDA) || defined(GGML_USE_METAL)   || defined(GGML_USE_VULKAN) || \
-    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
-    // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
-    return true;
-#else
-    return false;
-#endif
-}
-
-void llama_backend_init(void) {
-    ggml_time_init();
-
-    // needed to initialize f16 tables
-    {
-        struct ggml_init_params params = { 0, NULL, false };
-        struct ggml_context * ctx = ggml_init(params);
-        ggml_free(ctx);
-    }
-}
-
-void llama_numa_init(enum ggml_numa_strategy numa) {
-    if (numa != GGML_NUMA_STRATEGY_DISABLED) {
-        ggml_numa_init(numa);
-    }
-}
-
-void llama_backend_free(void) {
-    ggml_quantize_free();
-}
-
-int64_t llama_time_us(void) {
-    return ggml_time_us();
-}
-
-struct llama_model * llama_load_model_from_file(
-        const char * path_model,
-        struct llama_model_params   params) {
-    ggml_time_init();
-
-    llama_model * model = new llama_model;
-
-    unsigned cur_percentage = 0;
-    if (params.progress_callback == NULL) {
-        params.progress_callback_user_data = &cur_percentage;
-        params.progress_callback = [](float progress, void * ctx) {
-            unsigned * cur_percentage_p = (unsigned *) ctx;
-            unsigned percentage = (unsigned) (100 * progress);
-            while (percentage > *cur_percentage_p) {
-                *cur_percentage_p = percentage;
-                LLAMA_LOG_INFO(".");
-                if (percentage >= 100) {
-                    LLAMA_LOG_INFO("\n");
-                }
-            }
-            return true;
-        };
-    }
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(",")) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
-        }
-        model->rpc_servers.push_back(servers);
-    }
-    int status = llama_model_load(path_model, *model, params);
-    GGML_ASSERT(status <= 0);
-    if (status < 0) {
-        if (status == -1) {
-            LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
-        } else if (status == -2) {
-            LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
-        }
-        delete model;
-        return nullptr;
-    }
-
-    return model;
-}
-
-void llama_free_model(struct llama_model * model) {
-    delete model;
-}
-
-struct llama_context * llama_new_context_with_model(
-                 struct llama_model * model,
-        struct llama_context_params   params) {
-
-    if (!model) {
-        LLAMA_LOG_ERROR("%s: model cannot be NULL\n", __func__);
-        return nullptr;
-    }
-
-    if (params.n_batch == 0 && params.n_ubatch == 0) {
-        LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__);
-        return nullptr;
-    }
-
-    if (params.n_ctx == 0 && model->hparams.n_ctx_train == 0) {
-        LLAMA_LOG_ERROR("%s: n_ctx and model->hparams.n_ctx_train cannot both be zero\n", __func__);
-        return nullptr;
-    }
-
-    if (params.flash_attn && model->arch == LLM_ARCH_GROK) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
-        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
-        return nullptr;
-    }
-
-    llama_context * ctx = new llama_context(*model);
-
-    const auto & hparams = model->hparams;
-    auto       & cparams = ctx->cparams;
-
-    cparams.n_seq_max        = std::max(1u, params.n_seq_max);
-    cparams.n_threads        = params.n_threads;
-    cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
-    cparams.defrag_thold     = params.defrag_thold;
-    cparams.embeddings       = params.embeddings;
-    cparams.offload_kqv      = params.offload_kqv;
-    cparams.flash_attn       = params.flash_attn;
-    cparams.pooling_type     = params.pooling_type;
-
-    cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
-    cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
-    cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
-
-    // this is necessary due to kv_self.n being padded later during inference
-    cparams.n_ctx            = GGML_PAD(cparams.n_ctx, llama_kv_cache_get_padding(cparams));
-
-    // with causal attention, the batch size is limited by the context size
-    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
-
-    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
-    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
-    // ref: https://github.com/ggerganov/llama.cpp/pull/5021
-    if (cparams.n_batch < GGML_KQ_MASK_PAD) {
-        LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
-        cparams.n_batch = GGML_KQ_MASK_PAD;
-    }
-
-    cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
-
-    cparams.n_ctx_orig_yarn  = params.yarn_orig_ctx    != 0 ? params.yarn_orig_ctx    :
-                               hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
-                                                              hparams.n_ctx_train;
-
-    cparams.cb_eval           = params.cb_eval;
-    cparams.cb_eval_user_data = params.cb_eval_user_data;
-
-    auto rope_scaling_type = params.rope_scaling_type;
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
-        rope_scaling_type = hparams.rope_scaling_type_train;
-    }
-
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_NONE) {
-        cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
-    }
-
-    if (cparams.yarn_ext_factor < 0.0f) { // negative indicates 'not set'
-        cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
-    }
-
-    cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;
-
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-        if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
-            cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-        } else {
-            cparams.pooling_type = hparams.pooling_type;
-        }
-    }
-
-    if (params.seed == LLAMA_DEFAULT_SEED) {
-        params.seed = time(NULL);
-    }
-
-    LLAMA_LOG_INFO("%s: n_ctx      = %u\n",     __func__, cparams.n_ctx);
-    LLAMA_LOG_INFO("%s: n_batch    = %u\n",     __func__, cparams.n_batch);
-    LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",     __func__, cparams.n_ubatch);
-    LLAMA_LOG_INFO("%s: flash_attn = %d\n",     __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",   __func__, cparams.rope_freq_base);
-    LLAMA_LOG_INFO("%s: freq_scale = %g\n",     __func__, cparams.rope_freq_scale);
-
-    ctx->abort_callback      = params.abort_callback;
-    ctx->abort_callback_data = params.abort_callback_data;
-
-    ctx->rng                 = std::mt19937(params.seed);
-    ctx->logits_all          = params.logits_all;
-
-    uint32_t kv_size = cparams.n_ctx;
-    ggml_type type_k = params.type_k;
-    ggml_type type_v = params.type_v;
-
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
-
-    GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
-    GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
-
-    if (!hparams.vocab_only) {
-        // initialize backends
-#if defined(GGML_USE_METAL)
-        if (model->n_gpu_layers > 0) {
-            ctx->backend_metal = ggml_backend_metal_init();
-            if (ctx->backend_metal == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(ctx->backend_metal);
-        }
-#elif defined(GGML_USE_CUDA)
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-            ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-            for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_cuda_init(device);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_VULKAN)
-        if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
-            ggml_backend_t backend = ggml_backend_vk_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_vk_init(device);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_SYCL)
-        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            ggml_backend_t backend = ggml_backend_sycl_init(model->main_gpu);
-            if (backend == nullptr) {
-                int main_gpu_id = ggml_backend_sycl_get_device_id(model->main_gpu);
-                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, main_gpu_id, model->main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_LAYER requires a backend for each GPU
-            for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
-                ggml_backend_t backend = ggml_backend_sycl_init(i);
-                if (backend == nullptr) {
-                    int id_list[GGML_SYCL_MAX_DEVICES];
-                    ggml_sycl_get_gpu_list(id_list, GGML_SYCL_MAX_DEVICES);
-                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d (index %d) backend\n", __func__, id_list[i], i);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_KOMPUTE)
-        if (model->n_gpu_layers > 0) {
-            auto * backend = ggml_backend_kompute_init(model->main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Kompute backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-#endif
-#if defined(GGML_USE_RPC)
-        if (model->n_gpu_layers > 0) {
-            for (const auto & endpoint : model->rpc_servers) {
-                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#endif
-        ctx->backend_cpu = ggml_backend_cpu_init();
-        if (ctx->backend_cpu == nullptr) {
-            LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        ctx->backends.push_back(ctx->backend_cpu);
-
-        if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
-            LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-
-        {
-            size_t memory_size_k = 0;
-            size_t memory_size_v = 0;
-
-            for (auto & k : ctx->kv_self.k_l) {
-                memory_size_k += ggml_nbytes(k);
-            }
-
-            for (auto & v : ctx->kv_self.v_l) {
-                memory_size_v += ggml_nbytes(v);
-            }
-
-            LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
-                ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
-        }
-
-        // graph outputs buffer
-        {
-            // resized during inference when a batch uses more outputs
-            if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
-                LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-
-            LLAMA_LOG_INFO("%s: %10s  output buffer size = %8.2f MiB\n", __func__,
-                    ggml_backend_buffer_name(ctx->buf_output),
-                    ggml_backend_buffer_get_size(ctx->buf_output) / 1024.0 / 1024.0);
-        }
-
-        // scheduler and compute buffers
-        {
-            // buffer types used for the compute buffer of each backend
-            std::vector backend_buft;
-            for (auto * backend : ctx->backends) {
-                if (ggml_backend_is_cpu(backend)) {
-                    // use host buffers for the CPU backend compute buffer
-                    backend_buft.push_back(llama_default_buffer_type_cpu(true));
-                } else {
-                    backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
-                }
-            }
-
-            // buffer used to store the computation graph and the tensor meta data
-            ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
-
-            // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
-            bool pipeline_parallel =
-                llama_get_device_count(*model) > 1 &&
-                model->n_gpu_layers > (int)model->hparams.n_layer &&
-                model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
-                params.offload_kqv;
-#ifndef GGML_USE_CUDA
-            // pipeline parallelism requires support for async compute and events
-            // currently this is only implemented in the CUDA backend
-            pipeline_parallel = false;
-#endif
-            ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel);
-
-            if (pipeline_parallel) {
-                LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
-            }
-
-            // build worst-case graph
-            int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_ubatch);
-            int n_past = cparams.n_ctx - n_tokens;
-            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);
-
-            // initialize scheduler with the worst-case graph
-            if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-
-            for (size_t i = 0; i < ctx->backends.size(); i++) {
-                ggml_backend_t backend = ctx->backends[i];
-                ggml_backend_buffer_type_t buft = backend_buft[i];
-                size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
-                if (size > 1) {
-                    LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                            ggml_backend_buft_name(buft),
-                            size / 1024.0 / 1024.0);
-                }
-            }
-
-            // note: the number of splits during measure is higher than during inference due to the kv shift
-            int n_splits = ggml_backend_sched_get_n_splits(ctx->sched);
-            LLAMA_LOG_INFO("%s: graph nodes  = %d\n", __func__, gf->n_nodes);
-            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits);
-        }
-    }
-
-    return ctx;
-}
-
-void llama_free(struct llama_context * ctx) {
-    delete ctx;
-}
-
-const llama_model * llama_get_model(const struct llama_context * ctx) {
-    return &ctx->model;
-}
-
-uint32_t llama_n_ctx(const struct llama_context * ctx) {
-    return ctx->cparams.n_ctx;
-}
-
-uint32_t llama_n_batch(const struct llama_context * ctx) {
-    return ctx->cparams.n_batch;
-}
-
-uint32_t llama_n_ubatch(const struct llama_context * ctx) {
-    return ctx->cparams.n_ubatch;
-}
-
-uint32_t llama_n_seq_max(const struct llama_context * ctx) {
-    return ctx->kv_self.size;
-}
-
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
-}
-
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-    switch (model->arch) {
-        // these models do not use RoPE
-        case LLM_ARCH_GPT2:
-        case LLM_ARCH_GPTJ:
-        case LLM_ARCH_MPT:
-        case LLM_ARCH_REFACT:
-        case LLM_ARCH_BLOOM:
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_JINA_BERT_V2:
-            return LLAMA_ROPE_TYPE_NONE;
-
-        // use what we call a normal RoPE, operating on pairs of consecutive head values
-        case LLM_ARCH_LLAMA:
-        case LLM_ARCH_BAICHUAN:
-        case LLM_ARCH_STARCODER:
-        case LLM_ARCH_PLAMO:
-        case LLM_ARCH_CODESHELL:
-        case LLM_ARCH_ORION:
-        case LLM_ARCH_INTERNLM2:
-        case LLM_ARCH_MINICPM:
-        case LLM_ARCH_XVERSE:
-        case LLM_ARCH_COMMAND_R:
-        case LLM_ARCH_OLMO:
-        case LLM_ARCH_ARCTIC:
-        case LLM_ARCH_DEEPSEEK2:
-            return LLAMA_ROPE_TYPE_NORM;
-
-        // the pairs of head values are offset by n_rot/2
-        case LLM_ARCH_FALCON:
-        case LLM_ARCH_GROK:
-        case LLM_ARCH_DBRX:
-        case LLM_ARCH_BERT:
-        case LLM_ARCH_NOMIC_BERT:
-        case LLM_ARCH_STABLELM:
-        case LLM_ARCH_QWEN:
-        case LLM_ARCH_QWEN2:
-        case LLM_ARCH_QWEN2MOE:
-        case LLM_ARCH_PHI2:
-        case LLM_ARCH_PHI3:
-        case LLM_ARCH_GEMMA:
-        case LLM_ARCH_STARCODER2:
-        case LLM_ARCH_GPTNEOX:
-            return LLAMA_ROPE_TYPE_NEOX;
-
-        // all model arches should be listed explicitly here
-        case LLM_ARCH_UNKNOWN:
-            GGML_ASSERT(false && "unknown architecture");
-            break;
-    }
-
-    return LLAMA_ROPE_TYPE_NONE;
-}
-
-enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
-    return ctx->cparams.pooling_type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
-}
-
-int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
-}
-
-int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
-}
-
-float llama_rope_freq_scale_train(const struct llama_model * model) {
-    return model->hparams.rope_freq_scale_train;
-}
-
-int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size) {
-    const auto & it = model->gguf_kv.find(key);
-    if (it == model->gguf_kv.end()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_meta_count(const struct llama_model * model) {
-    return (int)model->gguf_kv.size();
-}
-
-int32_t llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->first.c_str());
-}
-
-int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size) {
-    if (i < 0 || i >= (int)model->gguf_kv.size()) {
-        if (buf_size > 0) {
-            buf[0] = '\0';
-        }
-        return -1;
-    }
-    auto it = model->gguf_kv.begin();
-    std::advance(it, i);
-    return snprintf(buf, buf_size, "%s", it->second.c_str());
-}
-
-int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name(model->arch),
-            llama_model_type_name(model->type),
-            llama_model_ftype_name(model->ftype).c_str());
-}
-
-uint64_t llama_model_size(const struct llama_model * model) {
-    uint64_t size = 0;
-    for (const auto & it : model->tensors_by_name) {
-        size += ggml_nbytes(it.second);
-    }
-    return size;
-}
-
-uint64_t llama_model_n_params(const struct llama_model * model) {
-    uint64_t nparams = 0;
-    for (const auto & it : model->tensors_by_name) {
-        nparams += ggml_nelements(it.second);
-    }
-    return nparams;
-}
-
-struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
-    auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
-            [name](const std::pair & it) {
-                return it.first == name;
-            });
-    if (it == model->tensors_by_name.end()) {
-        return nullptr;
-    }
-    return it->second;
-}
-
-uint32_t llama_model_quantize(
-        const char * fname_inp,
-        const char * fname_out,
-        const llama_model_quantize_params * params) {
-    try {
-        llama_model_quantize_internal(fname_inp, fname_out, params);
-        return 0;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
-        return 1;
-    }
-}
-
 int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
     try {
         return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
@@ -16586,2086 +19470,4 @@ int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
         return 1;
     }
-}
-
-static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) {
-    GGML_ASSERT(cvec.tensors.empty());
-    GGML_ASSERT(cvec.ctxs.empty());
-    GGML_ASSERT(cvec.bufs.empty());
-
-    // count layer buffer types
-    std::map buft_layer_count;
-    for (int64_t i = 0; i < model.hparams.n_layer; i++) {
-        buft_layer_count[model.buft_layer[i].buft]++;
-    }
-
-    // allocate contexts
-    std::map ctx_map;
-    for (auto & it : buft_layer_count) {
-        int n_layers = it.second;
-        struct ggml_init_params params = {
-            /*.mem_size   =*/ n_layers * ggml_tensor_overhead(),
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-        ggml_context * ctx = ggml_init(params);
-        if (!ctx) {
-            LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
-            return 1;
-        }
-        ctx_map[it.first] = ctx;
-    }
-
-    // make tensors
-    cvec.tensors.reserve(model.hparams.n_layer);
-    cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0
-    for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft);
-        ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd);
-        cvec.tensors.push_back(tensor);
-    }
-
-    // allocate tensors / buffers and zero
-    cvec.ctxs.reserve(ctx_map.size());
-    cvec.bufs.reserve(ctx_map.size());
-    for (auto it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx = it.second;
-        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-        if (!buf) {
-            LLAMA_LOG_ERROR("%s: failed to allocate buffer for control vector\n", __func__);
-            return false;
-        }
-        ggml_backend_buffer_clear(buf, 0);
-        cvec.ctxs.push_back(ctx);
-        cvec.bufs.push_back(buf);
-    }
-
-    return true;
-}
-
-int32_t llama_control_vector_apply(struct llama_context * lctx, const float * data, size_t len, int32_t n_embd, int32_t il_start, int32_t il_end) {
-    const llama_model & model = lctx->model;
-    llama_control_vector & cvec = lctx->cvec;
-
-    if (data == nullptr) {
-        // disable the current control vector (but leave allocated for later)
-        cvec.layer_start = -1;
-        cvec.layer_end   = -1;
-        return 0;
-    }
-
-    if (n_embd != (int) model.hparams.n_embd) {
-        LLAMA_LOG_ERROR("%s: control vector n_embd does not match model\n", __func__);
-        return 1;
-    }
-
-    if (cvec.tensors.empty()) {
-        if (!llama_control_vector_init(cvec, model)) {
-            return 1;
-        }
-    }
-
-    cvec.layer_start = il_start;
-    cvec.layer_end   = il_end;
-
-    for (size_t il = 1; il < model.hparams.n_layer; il++) {
-        assert(cvec.tensors[il] != nullptr);
-
-        const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
-        if (off + n_embd <= len) {
-            ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il]));
-        }
-    }
-
-    return 0;
-}
-
-struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_seq_max) {
-    struct llama_kv_cache_view result = {
-        /*.n_cells            = */ 0,
-        /*.n_seq_max          = */ n_seq_max,
-        /*.token_count        = */ 0,
-        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
-        /*.max_contiguous     = */ 0,
-        /*.max_contiguous_idx = */ -1,
-        /*.cells              = */ nullptr,
-        /*.cells_sequences    = */ nullptr,
-    };
-    return result;
-}
-
-void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
-    if (view->cells != nullptr) {
-        free(view->cells);
-        view->cells = nullptr;
-    }
-    if (view->cells_sequences != nullptr) {
-        free(view->cells_sequences);
-        view->cells_sequences = nullptr;
-    }
-}
-
-void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
-    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
-        view->n_cells = int32_t(ctx->kv_self.size);
-        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
-        view->cells = (struct llama_kv_cache_view_cell *)p;
-        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells);
-        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
-        view->cells_sequences = (llama_seq_id *)p;
-    }
-
-    const std::vector & kv_cells = ctx->kv_self.cells;
-    llama_kv_cache_view_cell * c_curr = view->cells;
-    llama_seq_id * cs_curr = view->cells_sequences;
-    int32_t used_cells = 0;
-    int32_t token_count = 0;
-    int32_t curr_contig_idx = -1;
-    uint32_t max_contig = 0;
-    int32_t max_contig_idx = -1;
-
-    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
-        const size_t curr_size = kv_cells[i].seq_id.size();
-        token_count += curr_size;
-        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
-
-        if (curr_size > 0) {
-            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
-                max_contig = i - curr_contig_idx;
-                max_contig_idx = curr_contig_idx;
-            }
-            curr_contig_idx = -1;
-        } else if (curr_contig_idx < 0) {
-            curr_contig_idx = i;
-        }
-
-        int seq_idx = 0;
-        for (const llama_seq_id it : kv_cells[i].seq_id) {
-            if (seq_idx >= view->n_seq_max) {
-                break;
-            }
-            cs_curr[seq_idx] = it;
-            seq_idx++;
-        }
-        if (seq_idx != 0) {
-            used_cells++;
-        }
-        for (; seq_idx < view->n_seq_max; seq_idx++) {
-            cs_curr[seq_idx] = -1;
-        }
-    }
-    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
-        max_contig_idx = curr_contig_idx;
-        max_contig = kv_cells.size() - curr_contig_idx;
-    }
-    view->max_contiguous = max_contig;
-    view->max_contiguous_idx = max_contig_idx;
-    view->token_count = token_count;
-    view->used_cells = used_cells;
-    if (uint32_t(used_cells) != ctx->kv_self.used) {
-        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, ctx->kv_self.used, used_cells);
-    }
-}
-
-int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    int result = 0;
-
-    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
-        result += ctx->kv_self.cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
-    return ctx->kv_self.used;
-}
-
-void llama_kv_cache_clear(struct llama_context * ctx) {
-    llama_kv_cache_clear(ctx->kv_self);
-}
-
-bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
-}
-
-void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
-    if (seq_id_src == seq_id_dst) {
-        return;
-    }
-    llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
-}
-
-void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
-}
-
-void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
-        return;
-    }
-
-    llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
-}
-
-void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
-    if (d == 1) {
-        return;
-    }
-
-    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
-}
-
-llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
-}
-
-void llama_kv_cache_defrag(struct llama_context * ctx) {
-    llama_kv_cache_defrag(ctx->kv_self);
-}
-
-void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_internal(*ctx);
-}
-
-// deprecated
-size_t llama_get_state_size(const struct llama_context * ctx) {
-    return llama_state_get_size(ctx);
-}
-
-// deprecated
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
-    return llama_state_get_data(ctx, dst);
-}
-
-// deprecated
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    return llama_state_set_data(ctx, src);
-}
-
-// deprecated
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    return llama_state_load_file(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-}
-
-// deprecated
-bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    return llama_state_save_file(ctx, path_session, tokens, n_token_count);
-}
-
-// Returns the *maximum* size of the state
-size_t llama_state_get_size(const struct llama_context * ctx) {
-    const auto & cparams = ctx->cparams;
-    const auto & hparams = ctx->model.hparams;
-
-    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
-    // for reference, std::mt19937(1337) serializes to 6701 bytes.
-    const size_t s_rng_size        = sizeof(size_t);
-    const size_t s_rng             = LLAMA_MAX_RNG_STATE;
-    const size_t s_n_outputs       = sizeof(size_t);
-    // assume worst case for outputs although only currently set ones are serialized
-    const size_t s_output_pos      = ctx->cparams.n_batch * sizeof(int32_t);
-    const size_t s_logits_size     = sizeof(size_t);
-    const size_t s_logits          = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
-    const size_t s_embedding_size  = sizeof(size_t);
-    const size_t s_embedding       = ctx->embd_size   ? cparams.n_batch * hparams.n_embd  * sizeof(float) : 0;
-    const size_t s_kv_buf_size     = sizeof(size_t);
-    const size_t s_kv_head         = sizeof(uint32_t);
-    const size_t s_kv_size         = sizeof(uint32_t);
-    const size_t s_kv_used         = sizeof(uint32_t);
-    const size_t s_v_trans         = sizeof(uint32_t);
-    const size_t s_kv              = ctx->kv_self.total_size();
-    const size_t s_kv_cell         = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
-    const size_t s_kv_cells        = ctx->kv_self.size * s_kv_cell;
-
-    const size_t s_total = (
-        + s_rng_size
-        + s_rng
-        + s_n_outputs
-        + s_output_pos
-        + s_logits_size
-        + s_logits
-        + s_embedding_size
-        + s_embedding
-        + s_kv_buf_size
-        + s_kv_head
-        + s_kv_size
-        + s_kv_used
-        + s_v_trans
-        + s_kv
-        + s_kv_cells
-    );
-
-    // on session change it is very likely that the state size has changed - so we need to update this function
-    static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");
-
-    return s_total;
-}
-
-// llama_context_data
-struct llama_data_context {
-    virtual void write(const void * src, size_t size) = 0;
-    virtual size_t get_size_written() = 0;
-    virtual ~llama_data_context() = default;
-};
-
-struct llama_data_buffer_context : llama_data_context {
-    uint8_t * ptr;
-    size_t size_written = 0;
-
-    llama_data_buffer_context(uint8_t * p) : ptr(p) {}
-
-    void write(const void * src, size_t size) override {
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-struct llama_data_file_context : llama_data_context {
-    llama_file * file;
-    size_t size_written = 0;
-
-    llama_data_file_context(llama_file * f) : file(f) {}
-
-    void write(const void * src, size_t size) override {
-        file->write_raw(src, size);
-        size_written += size;
-    }
-
-    size_t get_size_written() override {
-        return size_written;
-    }
-};
-
-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_file_context data_ctx(&file);
- * llama_state_get_data(ctx, &data_ctx);
- *
- * buffer context:
- * std::vector buf(max_size, 0);
- * llama_data_buffer_context data_ctx(&buf.data());
- * llama_state_get_data(ctx, &data_ctx);
- *
-*/
-static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
-    llama_synchronize(ctx);
-
-    // copy rng
-    {
-        std::ostringstream rng_ss;
-        rng_ss << ctx->rng;
-
-        const std::string & rng_str  = rng_ss.str();
-        const size_t        rng_size = rng_str.size();
-
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
-
-        data_ctx->write(&rng_size,      sizeof(rng_size));
-        data_ctx->write(rng_str.data(), rng_size);
-    }
-
-    // copy outputs
-    {
-        // Can't use ctx->n_outputs because it's not for the
-        // entire last batch when n_ubatch is smaller than n_batch
-        size_t n_outputs = 0;
-
-        // copy output ids
-        {
-            std::vector output_pos;
-
-            const size_t    n_batch = ctx->cparams.n_batch;
-            const auto & output_ids = ctx->output_ids;
-
-            output_pos.resize(ctx->output_size);
-
-            // build a more compact representation of the output ids
-            for (size_t i = 0; i < n_batch; ++i) {
-                // map an output id to a position in the batch
-                int32_t pos = output_ids[i];
-                if (pos >= 0) {
-                    if ((size_t) pos >= n_outputs) {
-                        n_outputs = pos + 1;
-                    }
-                    GGML_ASSERT((size_t) pos < ctx->output_size);
-                    output_pos[pos] = i;
-                }
-            }
-
-            data_ctx->write(&n_outputs, sizeof(n_outputs));
-
-            if (n_outputs) {
-                data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
-            }
-        }
-
-        // copy logits
-        {
-            const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
-
-            data_ctx->write(&logits_size, sizeof(logits_size));
-
-            if (logits_size) {
-                data_ctx->write(ctx->logits, logits_size * sizeof(float));
-            }
-        }
-
-        // copy embeddings
-        {
-            const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
-
-            data_ctx->write(&embeddings_size, sizeof(embeddings_size));
-
-            if (embeddings_size) {
-                data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
-            }
-        }
-    }
-
-    // copy kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
-
-        const uint32_t n_layer      = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
-
-        // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
-        const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
-        const uint32_t kv_size     = kv_self.size;
-        const size_t   kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
-        const uint32_t kv_used     = kv_self.used;
-        const uint32_t v_trans     = kv_self.v_trans ? 1 : 0;
-
-        data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
-        data_ctx->write(&kv_head,     sizeof(kv_head));
-        data_ctx->write(&kv_size,     sizeof(kv_size));
-        data_ctx->write(&kv_used,     sizeof(kv_used));
-        data_ctx->write(&v_trans,     sizeof(v_trans));
-
-        if (kv_buf_size) {
-            const size_t pre_kv_buf_size = data_ctx->get_size_written();
-
-            std::vector tmp_buf;
-            for (int il = 0; il < (int) n_layer; ++il) {
-                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
-
-                tmp_buf.resize(k_size);
-                ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
-                data_ctx->write(tmp_buf.data(), tmp_buf.size());
-
-                if (kv_self.recurrent || !kv_self.v_trans) {
-                    // v is contiguous for recurrent models
-                    // TODO: use other tensors for state models than k and v
-                    const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
-
-                    tmp_buf.resize(v_size);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), 0, tmp_buf.size());
-                    data_ctx->write(tmp_buf.data(), tmp_buf.size());
-                    continue;
-                }
-
-                // v is not contiguous, copy row by row
-                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
-
-                tmp_buf.resize(v_row_size);
-                for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), ir*v_row_stride, tmp_buf.size());
-                    data_ctx->write(tmp_buf.data(), tmp_buf.size());
-                }
-            }
-            GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
-        }
-
-        for (uint32_t i = 0; i < kv_head; ++i) {
-            const auto & cell = kv_self.cells[i];
-
-            const llama_pos pos         = cell.pos;
-            const size_t    seq_id_size = cell.seq_id.size();
-
-            data_ctx->write(&pos,         sizeof(pos));
-            data_ctx->write(&seq_id_size, sizeof(seq_id_size));
-
-            for (auto seq_id : cell.seq_id) {
-                data_ctx->write(&seq_id, sizeof(seq_id));
-            }
-        }
-    }
-}
-
-size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
-    llama_data_buffer_context data_ctx(dst);
-    llama_state_get_data_internal(ctx, &data_ctx);
-
-    return data_ctx.get_size_written();
-}
-
-// Sets the state reading from the specified source address
-size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
-    llama_synchronize(ctx);
-
-    const uint8_t * inp = src;
-
-    // set rng
-    {
-        size_t rng_size;
-        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
-
-        GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
-
-        std::string rng_str((const char *)inp, rng_size); inp += rng_size;
-
-        std::istringstream rng_ss(rng_str);
-        rng_ss >> ctx->rng;
-
-        GGML_ASSERT(!rng_ss.fail());
-    }
-
-    // set output ids
-    {
-        size_t n_outputs;
-        std::vector output_pos;
-
-        memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
-
-        GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
-            inp += n_outputs * sizeof(int32_t);
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
-                ctx->output_ids[id] = i;
-            }
-
-            ctx->n_outputs = n_outputs;
-        }
-    }
-
-    // set logits
-    {
-        size_t logits_size;
-
-        memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
-
-        GGML_ASSERT(ctx->logits_size >= logits_size);
-
-        if (logits_size) {
-            memcpy(ctx->logits, inp, logits_size * sizeof(float));
-            inp += logits_size * sizeof(float);
-        }
-    }
-
-    // set embeddings
-    {
-        size_t embeddings_size;
-
-        memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
-
-        GGML_ASSERT(ctx->embd_size >= embeddings_size);
-
-        if (embeddings_size) {
-            memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
-            inp += embeddings_size * sizeof(float);
-        }
-    }
-
-    // set kv cache
-    {
-        const auto & kv_self = ctx->kv_self;
-        const auto & hparams = ctx->model.hparams;
-
-        const uint32_t n_layer      = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
-
-        size_t   kv_buf_size;
-        uint32_t kv_head;
-        uint32_t kv_size;
-        uint32_t kv_used;
-        uint32_t v_trans;
-
-        memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
-        memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
-        memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
-        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
-        memcpy(&v_trans,     inp, sizeof(v_trans));     inp += sizeof(v_trans);
-
-        GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
-
-        if (kv_self.size != kv_size) {
-            // the KV cache needs to be big enough to load all the KV cells from the saved state
-            GGML_ASSERT(kv_self.size >= kv_head);
-
-            LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
-                __func__, kv_head, kv_size, kv_self.size);
-        }
-
-        llama_kv_cache_clear(ctx);
-
-        if (kv_buf_size) {
-            const size_t pre_kv_buf_size = inp - src;
-
-            GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
-
-            for (int il = 0; il < (int) n_layer; ++il) {
-                const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
-
-                ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
-                inp += k_size;
-
-                if (kv_self.recurrent || !kv_self.v_trans) {
-                    // v is contiguous for recurrent models
-                    // TODO: use other tensors for state models than k and v
-                    const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
-
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
-                    inp += v_size;
-                    continue;
-                }
-
-                // v is not contiguous, copy row by row
-                const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
-
-                for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
-                    inp += v_row_size;
-                }
-            }
-            GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
-        }
-
-        ctx->kv_self.head = kv_head;
-        ctx->kv_self.used = kv_used;
-
-        for (uint32_t i = 0; i < kv_head; ++i) {
-            llama_pos pos;
-            size_t    seq_id_size;
-
-            memcpy(&pos,         inp, sizeof(pos));         inp += sizeof(pos);
-            memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
-
-            ctx->kv_self.cells[i].pos = pos;
-
-            llama_seq_id seq_id;
-
-            for (size_t j = 0; j < seq_id_size; ++j) {
-                memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
-                ctx->kv_self.cells[i].seq_id.insert(seq_id);
-            }
-        }
-    }
-
-    const size_t nread    = inp - src;
-    const size_t max_size = llama_state_get_size(ctx);
-
-    GGML_ASSERT(nread <= max_size);
-
-    return nread;
-}
-
-static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(path_session, "rb");
-
-    // sanity checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
-        }
-
-        llama_hparams session_hparams;
-        file.read_raw(&session_hparams, sizeof(llama_hparams));
-
-        if (session_hparams != ctx->model.hparams) {
-            LLAMA_LOG_INFO("%s : model hparams didn't match from session file!\n", __func__);
-            return false;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return false;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_max = llama_state_get_size(ctx);
-
-        if (n_state_size_cur > n_state_size_max) {
-            LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
-            return false;
-        }
-
-        std::vector state_data(n_state_size_max);
-        file.read_raw(state_data.data(), n_state_size_cur);
-
-        llama_state_set_data(ctx, state_data.data());
-    }
-
-    return true;
-}
-
-bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
-        return false;
-    }
-}
-
-static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(path_session, "wb");
-
-    file.write_u32(LLAMA_SESSION_MAGIC);
-    file.write_u32(LLAMA_SESSION_VERSION);
-
-    file.write_raw(&ctx->model.hparams, sizeof(llama_hparams));
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_file_context data_ctx(&file);
-    llama_state_get_data_internal(ctx, &data_ctx);
-
-    return true;
-}
-
-bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
-        return false;
-    }
-}
-
-size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
-    // save the size of size_t as a uint32_t for safety check
-    const size_t size_t_size_size = sizeof(uint32_t);
-
-    // other values
-    const size_t s_cell_count_size = sizeof(uint32_t);
-    const size_t s_layer_count_size = sizeof(uint32_t);
-    const size_t n_embd_v_gqa_size = sizeof(uint32_t);
-
-    size_t s_cell_count = 0;
-    size_t s_cell_data_size = 0;
-    const auto & kv_self = ctx->kv_self;
-    const auto & hparams = ctx->model.hparams;
-
-    const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
-
-    for (uint32_t i = 0; i < kv_self.size; ++i) {
-        const auto & cell = kv_self.cells[i];
-        if (cell.seq_id.count(seq_id) > 0) {
-            ++s_cell_count;
-            s_cell_data_size += sizeof(llama_pos);
-        }
-    }
-
-    for (int il = 0; il < (int)n_layer; ++il) {
-        // types of keys and values
-        s_cell_data_size += sizeof(int32_t) * 2;
-        // k_size_row and v_size_el values of layer
-        s_cell_data_size += sizeof(size_t) * 2;
-
-        // keys
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        s_cell_data_size += k_size_row * s_cell_count;
-
-        // values (transposed)
-        const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-        s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
-    }
-
-    const size_t s_total = (
-        size_t_size_size +
-        s_cell_count_size +
-        s_layer_count_size +
-        n_embd_v_gqa_size +
-        s_cell_data_size
-        );
-
-    return s_total;
-}
-
-static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    const auto & kv_self = ctx->kv_self;
-    GGML_ASSERT(!kv_self.recurrent); // not implemented
-
-    // Save the size of size_t as a uint32_t for safety check
-    const uint32_t size_t_size = sizeof(size_t);
-    data_ctx.write(&size_t_size, sizeof(size_t_size));
-
-    std::vector> cell_ranges; // ranges, from inclusive, to exclusive
-    uint32_t cell_count = 0;
-
-    // Count the number of cells with the specified seq_id
-    // Find all the ranges of cells with this seq id
-    {
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if (cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            }
-            else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-    }
-
-    // Write the cell count
-    data_ctx.write(&cell_count, sizeof(cell_count));
-
-    const auto & hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
-
-    // Write the layer count
-    data_ctx.write(&n_layer, sizeof(n_layer));
-
-    // Write n_embd_v_gqa
-    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
-
-    // Iterate the ranges and write all the pos (this is the token position in the prompt)
-    for (const auto & range : cell_ranges) {
-        for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = kv_self.cells[i];
-            data_ctx.write(&cell.pos, sizeof(cell.pos));
-        }
-    }
-
-    // Iterate and write all the keys first, each row is a cell
-    // Get whole range at a time
-    std::vector tmp_buf;
-    for (int il = 0; il < (int)n_layer; ++il) {
-        // Write key type
-        const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-        data_ctx.write(&k_type_i, sizeof(k_type_i));
-
-        // Write row size of key
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        data_ctx.write(&k_size_row, sizeof(k_size_row));
-
-        // Read each range of cells of k_size length each into tmp_buf and write out
-        for (const auto & range : cell_ranges) {
-            const size_t range_size = range.second - range.first;
-            tmp_buf.resize(range_size * k_size_row);
-            ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
-            data_ctx.write(tmp_buf.data(), tmp_buf.size());
-        }
-    }
-
-    // TODO: simplify, reduce copy-paste
-    if (!kv_self.v_trans) {
-        for (int il = 0; il < (int)n_layer; ++il) {
-            // Write value type
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            data_ctx.write(&v_type_i, sizeof(v_type_i));
-
-            // Write row size of value
-            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            data_ctx.write(&v_size_row, sizeof(v_size_row));
-
-            // Read each range of cells of v_size length each into tmp_buf and write out
-            for (const auto & range : cell_ranges) {
-                const size_t range_size = range.second - range.first;
-                tmp_buf.resize(range_size * v_size_row);
-                ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
-                data_ctx.write(tmp_buf.data(), tmp_buf.size());
-            }
-        }
-    } else {
-        // For the values, they are transposed, so we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = kv_self.size;
-        for (int il = 0; il < (int)n_layer; ++il) {
-            // Write value type
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            data_ctx.write(&v_type_i, sizeof(v_type_i));
-
-            // Write element size
-            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-            data_ctx.write(&v_size_el, sizeof(v_size_el));
-
-            // For each row, we get the element values of each cell
-            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                // Read each range of cells of v_size_el length each into tmp_buf and write out
-                for (const auto & range : cell_ranges) {
-                    const size_t range_size = range.second - range.first;
-                    const size_t src_offset = (range.first + j * kv_size) * v_size_el;
-                    tmp_buf.resize(range_size * v_size_el);
-                    ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
-                    data_ctx.write(tmp_buf.data(), tmp_buf.size());
-                }
-            }
-        }
-    }
-
-    return data_ctx.get_size_written();
-}
-
-size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
-    llama_data_buffer_context data_ctx(dst);
-    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-}
-
-size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
-    llama_synchronize(ctx);
-
-    auto & kv_self = ctx->kv_self;
-    GGML_ASSERT(!kv_self.recurrent); // not implemented
-
-    // Wipe the slot
-    llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-
-    const uint8_t * inp = src;
-
-    // Read size of size_t
-    uint32_t size_t_size;
-    memcpy(&size_t_size, inp, sizeof(size_t_size));
-    inp += sizeof(size_t_size);
-    if (size_t_size != sizeof(size_t)) {
-        LLAMA_LOG_ERROR("%s: size_t size mismatch\n", __func__);
-        return 0;
-    }
-
-    // Read the cell count
-    uint32_t cell_count;
-    memcpy(&cell_count, inp, sizeof(cell_count));
-    inp += sizeof(cell_count);
-
-    // Read the layer count
-    uint32_t n_layer_ref;
-    memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
-    inp += sizeof(n_layer_ref);
-
-    // Read n_embd_v_gqa
-    uint32_t n_embd_v_gqa_ref;
-    memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
-    inp += sizeof(n_embd_v_gqa_ref);
-
-    // Sanity check model compatibility
-    const auto & hparams = ctx->model.hparams;
-    const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
-    if (n_layer != n_layer_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
-        return 0;
-    }
-    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
-        return 0;
-    }
-
-    // Allocate the new cells for the slot
-    if (cell_count) {
-        llama_batch batch = llama_batch_init(cell_count, 0, 1);
-        batch.n_tokens = cell_count;
-        for (uint32_t i = 0; i < cell_count; ++i) {
-            llama_pos pos;
-            memcpy(&pos, inp, sizeof(pos));
-            inp += sizeof(pos);
-
-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id[i][0] = dest_seq_id;
-        }
-        if (!llama_kv_cache_find_slot(kv_self, batch)) {
-            llama_batch_free(batch);
-            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-            return 0;
-        }
-
-        // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-
-        // Cleanup
-        llama_batch_free(batch);
-    }
-
-    const uint32_t kv_size = kv_self.size;
-    const uint32_t kv_head = kv_self.head;
-
-    // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
-    for (int il = 0; il < (int)n_layer; ++il) {
-        // Read type of key
-        int32_t k_type_i_ref;
-        memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
-        inp += sizeof(k_type_i_ref);
-        const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
-        if (k_type_i != k_type_i_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return 0;
-        }
-
-        // Read row size of key
-        size_t k_size_row_ref;
-        memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
-        inp += sizeof(k_size_row_ref);
-        const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-            LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, k_size_row_ref, il);
-            return 0;
-        }
-
-        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
-            inp += cell_count * k_size_row;
-        }
-    }
-
-    // TODO: simplify, reduce copy-paste
-    if (!kv_self.v_trans) {
-        for (int il = 0; il < (int)n_layer; ++il) {
-            // Read type of value
-            int32_t v_type_i_ref;
-            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-            inp += sizeof(v_type_i_ref);
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return 0;
-            }
-
-            // Read row size of value
-            size_t v_size_row_ref;
-            memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
-            inp += sizeof(v_size_row_ref);
-            const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
-                return 0;
-            }
-
-            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
-                inp += cell_count * v_size_row;
-            }
-        }
-    } else {
-        // For each layer, read the values for each cell (transposed)
-        for (int il = 0; il < (int)n_layer; ++il) {
-            // Read type of value
-            int32_t v_type_i_ref;
-            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
-            inp += sizeof(v_type_i_ref);
-            const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
-            if (v_type_i != v_type_i_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return 0;
-            }
-
-            // Read element size of value
-            size_t v_size_el_ref;
-            memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
-            inp += sizeof(v_size_el_ref);
-            const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
-            if (v_size_el != v_size_el_ref) {
-                llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-                LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
-                return 0;
-            }
-
-            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
-                    ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
-                    inp += cell_count * v_size_el;
-                }
-            }
-        }
-    }
-
-    const size_t nread = inp - src;
-
-    return nread;
-}
-
-static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
-
-    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
-    file.write_u32(LLAMA_STATE_SEQ_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t)n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_data_file_context data_ctx(&file);
-    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
-
-    const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
-    return res;
-}
-
-static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
-
-    // version checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
-            return 0;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return 0;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t state_size = file.size - file.tell();
-        std::vector state_data(state_size);
-        file.read_raw(state_data.data(), state_size);
-        const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
-        if (!nread) {
-            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
-            return 0;
-        }
-        GGML_ASSERT(nread <= state_size);
-        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
-    }
-
-    return file.tell();
-}
-
-size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
-    try {
-        return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
-        return 0;
-    }
-}
-
-size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    try {
-        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
-        return 0;
-    }
-}
-
-void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
-    ctx->cparams.n_threads       = n_threads;
-    ctx->cparams.n_threads_batch = n_threads_batch;
-}
-
-uint32_t llama_n_threads(struct llama_context * ctx) {
-    return ctx->cparams.n_threads;
-}
-
-uint32_t llama_n_threads_batch(struct llama_context * ctx) {
-    return ctx->cparams.n_threads_batch;
-}
-
-void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
-    ctx->abort_callback      = abort_callback;
-    ctx->abort_callback_data = abort_callback_data;
-}
-
-void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
-    ctx->cparams.causal_attn = causal_attn;
-}
-
-struct llama_batch llama_batch_get_one(
-             llama_token * tokens,
-                 int32_t   n_tokens,
-               llama_pos   pos_0,
-            llama_seq_id   seq_id) {
-    return {
-        /*n_tokens       =*/ n_tokens,
-        /*tokens         =*/ tokens,
-        /*embd           =*/ nullptr,
-        /*pos            =*/ nullptr,
-        /*n_seq_id       =*/ nullptr,
-        /*seq_id         =*/ nullptr,
-        /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ pos_0,
-        /*all_pos_1      =*/ 1,
-        /*all_seq_id     =*/ seq_id,
-    };
-}
-
-struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
-
-    if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
-    } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
-    }
-
-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
-    for (int i = 0; i < n_tokens_alloc; ++i) {
-        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch.seq_id[n_tokens_alloc] = nullptr;
-
-    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
-
-    return batch;
-}
-
-void llama_batch_free(struct llama_batch batch) {
-    if (batch.token)    free(batch.token);
-    if (batch.embd)     free(batch.embd);
-    if (batch.pos)      free(batch.pos);
-    if (batch.n_seq_id) free(batch.n_seq_id);
-    if (batch.seq_id) {
-        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
-            free(batch.seq_id[i]);
-        }
-        free(batch.seq_id);
-    }
-    if (batch.logits)   free(batch.logits);
-}
-
-int32_t llama_decode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
-    if (ret < 0) {
-        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
-void llama_synchronize(struct llama_context * ctx) {
-    ggml_backend_sched_synchronize(ctx->sched);
-
-    // FIXME: if multiple single tokens are evaluated without a synchronization,
-    // the stats will be added to the prompt evaluation stats
-    // this should only happen when using batch size 1 to evaluate a batch
-
-    // add the evaluation to the stats
-    if (ctx->n_queued_tokens == 1) {
-        ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us;
-        ctx->n_eval++;
-    } else if (ctx->n_queued_tokens > 1) {
-        ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us;
-        ctx->n_p_eval += ctx->n_queued_tokens;
-    }
-
-    // get a more accurate load time, upon first eval
-    if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) {
-        ctx->t_load_us = ggml_time_us() - ctx->t_start_us;
-        ctx->has_evaluated_once = true;
-    }
-
-    ctx->n_queued_tokens = 0;
-    ctx->t_compute_start_us = 0;
-}
-
-float * llama_get_logits(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
-    return ctx->logits;
-}
-
-float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
-    }
-}
-
-float * llama_get_embeddings(struct llama_context * ctx) {
-    llama_synchronize(ctx);
-
-    return ctx->embd;
-}
-
-float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
-    int32_t j = -1;
-
-    llama_synchronize(ctx);
-
-    try {
-        if (ctx->embd == nullptr) {
-            throw std::runtime_error("no embeddings");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->embd + j*ctx->model.hparams.n_embd;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
-    }
-}
-
-float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_synchronize(ctx);
-
-    auto it = ctx->embd_seq.find(seq_id);
-    if (it == ctx->embd_seq.end()) {
-        return nullptr;
-    }
-
-    return it->second.data();
-}
-
-const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].text.c_str();
-}
-
-float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].score;
-}
-
-llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].attr;
-}
-
-bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return token != -1 && (
-        token == llama_token_eos(model) ||
-        token == llama_token_eot(model)
-    );
-}
-
-bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_is_control_token(model->vocab, token);
-}
-
-llama_token llama_token_bos(const struct llama_model * model) {
-    return model->vocab.special_bos_id;
-}
-
-llama_token llama_token_eos(const struct llama_model * model) {
-    return model->vocab.special_eos_id;
-}
-
-llama_token llama_token_cls(const struct llama_model * model) {
-    return model->vocab.special_cls_id;
-}
-
-llama_token llama_token_sep(const struct llama_model * model) {
-    return model->vocab.special_sep_id;
-}
-
-llama_token llama_token_nl(const struct llama_model * model) {
-    return model->vocab.linefeed_id;
-}
-
-int32_t llama_add_bos_token(const struct llama_model * model) {
-    return model->vocab.special_add_bos;
-}
-
-int32_t llama_add_eos_token(const struct llama_model * model) {
-    return model->vocab.special_add_eos;
-}
-
-llama_token llama_token_prefix(const struct llama_model * model) {
-    return model->vocab.special_prefix_id;
-}
-
-llama_token llama_token_middle(const struct llama_model * model) {
-    return model->vocab.special_middle_id;
-}
-
-llama_token llama_token_suffix(const struct llama_model * model) {
-    return model->vocab.special_suffix_id;
-}
-
-llama_token llama_token_eot(const struct llama_model * model) {
-    return model->vocab.special_eot_id;
-}
-
-int32_t llama_tokenize(
-    const struct llama_model * model,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
-
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
-
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
-}
-
-static std::string llama_decode_text(const std::string & text) {
-    std::string decoded_text;
-
-    const auto cpts = unicode_cpts_from_utf8(text);
-    for (const auto cpt : cpts) {
-        const auto utf8 = unicode_cpt_to_utf8(cpt);
-        try {
-            decoded_text += unicode_utf8_to_byte(utf8);
-        } catch (const std::out_of_range & e) {
-            decoded_text += "[UNK_BYTE_0x";
-            for (const auto c : utf8) {
-                decoded_text += format("%02x", (uint8_t) c);
-            }
-            decoded_text += text + "]";
-        }
-    }
-
-    return decoded_text;
-}
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
-    if (!special && llama_is_control_token(model->vocab, token)) {
-        return 0;
-    }
-
-    // if we have a cache - use it
-    {
-        const auto & cache = model->vocab.cache_token_to_piece;
-
-        if (!cache.empty()) {
-            const auto & res = cache.at(token);
-            if (length < (int) res.size()) {
-                return -(int) res.size();
-            }
-            memcpy(buf, res.c_str(), res.size());
-            return res.size();
-        }
-    }
-
-    if (0 <= token && token < llama_n_vocab(model)) {
-        switch (llama_vocab_get_type(model->vocab)) {
-            case LLAMA_VOCAB_TYPE_WPM:
-            case LLAMA_VOCAB_TYPE_SPM: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (llama_is_normal_token(model->vocab, token)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    llama_unescape_whitespace(result);
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                } else if (
-                        (llama_is_user_defined_token(model->vocab, token)) ||
-                        (llama_is_control_token     (model->vocab, token) && special)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                    if (length < 3) {
-                        return -3;
-                    }
-                    memcpy(buf, "\xe2\x96\x85", 3);
-                    return 3;
-                } else if (llama_is_byte_token(model->vocab, token)) {
-                    if (length < 1) {
-                        return -1;
-                    }
-                    buf[0] = llama_token_to_byte(model->vocab, token);
-                    return 1;
-                }
-                break;
-            }
-            case LLAMA_VOCAB_TYPE_BPE: {
-                // NOTE: we accept all unsupported token types,
-                // suppressing them like CONTROL tokens.
-                if (llama_is_normal_token(model->vocab, token)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    result = llama_decode_text(result);
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                } else if (
-                        (llama_is_user_defined_token(model->vocab, token)) ||
-                        (llama_is_control_token     (model->vocab, token) && special)) {
-                    std::string result = model->vocab.id_to_token[token].text;
-                    if (length < (int) result.length()) {
-                        return -(int) result.length();
-                    }
-                    memcpy(buf, result.c_str(), result.length());
-                    return result.length();
-                }
-                break;
-            }
-            default:
-                GGML_ASSERT(false);
-        }
-    }
-    return 0;
-}
-
-// trim whitespace from the beginning and end of a string
-static std::string trim(const std::string & str) {
-    size_t start = 0;
-    size_t end = str.size();
-    while (start < end && isspace(str[start])) {
-        start += 1;
-    }
-    while (end > start && isspace(str[end - 1])) {
-        end -= 1;
-    }
-    return str.substr(start, end - start);
-}
-
-// Simple version of "llama_apply_chat_template" that only works with strings
-// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-static int32_t llama_chat_apply_template_internal(
-    const std::string & tmpl,
-    const std::vector & chat,
-    std::string & dest, bool add_ass) {
-    // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
-    std::stringstream ss;
-    if (tmpl == "chatml" || tmpl.find("<|im_start|>") != std::string::npos) {
-        // chatml template
-        for (auto message : chat) {
-            ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|im_start|>assistant\n";
-        }
-    } else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
-        // llama2 template and its variants
-        // [variant] support system message
-        bool support_system_message = tmpl.find("<>") != std::string::npos;
-        // [variant] space before + after response
-        bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
-        // [variant] add BOS inside history
-        bool add_bos_inside_history = tmpl.find("bos_token + '[INST]") != std::string::npos;
-        // [variant] trim spaces from the input message
-        bool strip_message = tmpl.find("content.strip()") != std::string::npos;
-        // construct the prompt
-        bool is_inside_turn = true; // skip BOS at the beginning
-        ss << "[INST] ";
-        for (auto message : chat) {
-            std::string content = strip_message ? trim(message->content) : message->content;
-            std::string role(message->role);
-            if (!is_inside_turn) {
-                is_inside_turn = true;
-                ss << (add_bos_inside_history ? "[INST] " : "[INST] ");
-            }
-            if (role == "system") {
-                if (support_system_message) {
-                    ss << "<>\n" << content << "\n<>\n\n";
-                } else {
-                    // if the model does not support system message, we still include it in the first message, but without <>
-                    ss << content << "\n";
-                }
-            } else if (role == "user") {
-                ss << content << " [/INST]";
-            } else {
-                ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "";
-                is_inside_turn = false;
-            }
-        }
-        // llama2 templates seem to not care about "add_generation_prompt"
-    } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) {
-        // Phi 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|" << role << "|>\n" << message->content << "<|end|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) {
-        // zephyr template
-        for (auto message : chat) {
-            ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
-        }
-        if (add_ass) {
-            ss << "<|assistant|>\n";
-        }
-    } else if (tmpl == "monarch" || tmpl.find("bos_token + message['role']") != std::string::npos) {
-        // mlabonne/AlphaMonarch-7B template (the  is included inside history)
-        for (auto message : chat) {
-            std::string bos = (message == chat.front()) ? "" : ""; // skip BOS for first message
-            ss << bos << message->role << "\n" << message->content << "\n";
-        }
-        if (add_ass) {
-            ss << "assistant\n";
-        }
-    } else if (tmpl == "gemma" || tmpl.find("") != std::string::npos) {
-        // google/gemma-7b-it
-        std::string system_prompt = "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-                system_prompt = trim(message->content);
-                continue;
-            }
-            // in gemma, "assistant" is "model"
-            role = role == "assistant" ? "model" : message->role;
-            ss << "" << role << "\n";
-            if (!system_prompt.empty() && role != "model") {
-                ss << system_prompt << "\n\n";
-                system_prompt = "";
-            }
-            ss << trim(message->content) << "\n";
-        }
-        if (add_ass) {
-            ss << "model\n";
-        }
-    } else if (tmpl == "orion" || tmpl.find("'\\n\\nAssistant: ' + eos_token") != std::string::npos) {
-        // OrionStarAI/Orion-14B-Chat
-        std::string system_prompt = "";
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // there is no system message support, we will merge it with user prompt
-                system_prompt = message->content;
-                continue;
-            } else if (role == "user") {
-                ss << "Human: ";
-                if (!system_prompt.empty()) {
-                    ss << system_prompt << "\n\n";
-                    system_prompt = "";
-                }
-                ss << message->content << "\n\nAssistant: ";
-            } else {
-                ss << message->content << "";
-            }
-        }
-    } else if (tmpl == "openchat" || tmpl.find("GPT4 Correct ") != std::string::npos) {
-        // openchat/openchat-3.5-0106,
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content << "<|end_of_turn|>";
-            } else {
-                role[0] = toupper(role[0]);
-                ss << "GPT4 Correct " << role << ": " << message->content << "<|end_of_turn|>";
-            }
-        }
-        if (add_ass) {
-            ss << "GPT4 Correct Assistant:";
-        }
-    } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl.find("USER: ") != std::string::npos && tmpl.find("ASSISTANT: ") != std::string::npos)) {
-        // eachadea/vicuna-13b-1.1 (and Orca variant)
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                // Orca-Vicuna variant uses a system prefix
-                if (tmpl == "vicuna-orca" || tmpl.find("SYSTEM: ") != std::string::npos) {
-                    ss << "SYSTEM: " << message->content << "\n";
-                } else {
-                    ss << message->content << "\n\n";
-                }
-            } else if (role == "user") {
-                ss << "USER: " << message->content << "\n";
-            } else if (role == "assistant") {
-                ss << "ASSISTANT: " << message->content << "\n";
-            }
-        }
-        if (add_ass) {
-            ss << "ASSISTANT:";
-        }
-    } else if (tmpl == "deepseek" || (tmpl.find("### Instruction:") != std::string::npos && tmpl.find("<|EOT|>") != std::string::npos)) {
-        // deepseek-ai/deepseek-coder-33b-instruct
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << message->content;
-            } else if (role == "user") {
-                ss << "### Instruction:\n" << message->content << "\n";
-            } else if (role == "assistant") {
-                ss << "### Response:\n" << message->content << "\n<|EOT|>\n";
-            }
-        }
-        if (add_ass) {
-            ss << "### Response:\n";
-        }
-    } else if (tmpl == "command-r" || (tmpl.find("<|START_OF_TURN_TOKEN|>") != std::string::npos && tmpl.find("<|USER_TOKEN|>") != std::string::npos)) {
-        // CohereForAI/c4ai-command-r-plus
-        for (auto message : chat) {
-            std::string role(message->role);
-            if (role == "system") {
-                ss << "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            } else if (role == "user") {
-                ss << "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            } else if (role == "assistant") {
-                ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" << trim(message->content) << "<|END_OF_TURN_TOKEN|>";
-            }
-        }
-        if (add_ass) {
-            ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>";
-        }
-    } else if (tmpl == "llama3" || (tmpl.find("<|start_header_id|>") != std::string::npos && tmpl.find("<|end_header_id|>") != std::string::npos)) {
-        // Llama 3
-        for (auto message : chat) {
-            std::string role(message->role);
-            ss << "<|start_header_id|>" << role << "<|end_header_id|>\n\n" << trim(message->content) << "<|eot_id|>";
-        }
-        if (add_ass) {
-            ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
-        }
-    } else {
-        // template not supported
-        return -1;
-    }
-    dest = ss.str();
-    return dest.size();
-}
-
-LLAMA_API int32_t llama_chat_apply_template(
-                const struct llama_model * model,
-                              const char * tmpl,
-         const struct llama_chat_message * chat,
-                                  size_t   n_msg,
-                                    bool   add_ass,
-                                    char * buf,
-                                 int32_t   length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-        // load template from model
-        std::vector model_template(2048, 0); // longest known template is about 1200 bytes
-        std::string template_key = "tokenizer.chat_template";
-        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        if (res < 0) {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml"; // see llama_chat_apply_template_internal
-        } else {
-            curr_tmpl = std::string(model_template.data(), model_template.size());
-        }
-    }
-
-    // format the chat to string
-    std::vector chat_vec;
-    chat_vec.resize(n_msg);
-    for (size_t i = 0; i < n_msg; i++) {
-        chat_vec[i] = &chat[i];
-    }
-
-    std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass);
-    if (res < 0) {
-        return res;
-    }
-    if (buf && length > 0) {
-        strncpy(buf, formatted_chat.c_str(), length);
-    }
-    return res;
-}
-
-LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) {
-    static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf";
-    if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) {
-        return strlen(split_path);
-    }
-    return 0;
-}
-
-int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
-    std::string str_split_path(split_path);
-    char postfix[32];
-    snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
-    std::string str_postfix(postfix);
-
-    // check if dest ends with postfix
-    int size_prefix = str_split_path.size() - str_postfix.size();
-    if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
-        return size_prefix;
-    }
-
-    return 0;
-}
-
-struct llama_timings llama_get_timings(struct llama_context * ctx) {
-    struct llama_timings result = {
-        /*.t_start_ms  =*/ 1e-3 * ctx->t_start_us,
-        /*.t_end_ms    =*/ 1.00 * ggml_time_ms(),
-        /*.t_load_ms   =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sample_ms =*/ 1e-3 * ctx->t_sample_us,
-        /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
-        /*.t_eval_ms   =*/ 1e-3 * ctx->t_eval_us,
-
-        /*.n_sample =*/ std::max(1, ctx->n_sample),
-        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
-        /*.n_eval   =*/ std::max(1, ctx->n_eval),
-    };
-
-    return result;
-}
-
-void llama_print_timings(struct llama_context * ctx) {
-    const llama_timings timings = llama_get_timings(ctx);
-
-    LLAMA_LOG_INFO("\n");
-    LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
-    LLAMA_LOG_INFO("%s:      sample time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample);
-    LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval);
-    LLAMA_LOG_INFO("%s:       total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval));
-}
-
-void llama_reset_timings(struct llama_context * ctx) {
-    ctx->t_start_us = ggml_time_us();
-    ctx->t_sample_us = ctx->n_sample = 0;
-    ctx->t_eval_us   = ctx->n_eval   = 0;
-    ctx->t_p_eval_us = ctx->n_p_eval = 0;
-}
-
-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s  = "";
-    s += "AVX = "         + std::to_string(ggml_cpu_has_avx())         + " | ";
-    s += "AVX_VNNI = "    + std::to_string(ggml_cpu_has_avx_vnni())    + " | ";
-    s += "AVX2 = "        + std::to_string(ggml_cpu_has_avx2())        + " | ";
-    s += "AVX512 = "      + std::to_string(ggml_cpu_has_avx512())      + " | ";
-    s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
-    s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
-    s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
-    s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
-    s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
-    s += "SVE = "         + std::to_string(ggml_cpu_has_sve())         + " | ";
-    s += "ARM_FMA = "     + std::to_string(ggml_cpu_has_arm_fma())     + " | ";
-    s += "F16C = "        + std::to_string(ggml_cpu_has_f16c())        + " | ";
-    s += "FP16_VA = "     + std::to_string(ggml_cpu_has_fp16_va())     + " | ";
-    s += "WASM_SIMD = "   + std::to_string(ggml_cpu_has_wasm_simd())   + " | ";
-    s += "BLAS = "        + std::to_string(ggml_cpu_has_blas())        + " | ";
-    s += "SSE3 = "        + std::to_string(ggml_cpu_has_sse3())        + " | ";
-    s += "SSSE3 = "       + std::to_string(ggml_cpu_has_ssse3())       + " | ";
-    s += "VSX = "         + std::to_string(ggml_cpu_has_vsx())         + " | ";
-    s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
-#ifdef GGML_USE_LLAMAFILE
-    s += "LLAMAFILE = 1 | ";
-#else
-    s += "LLAMAFILE = 0 | ";
-#endif
-
-    return s.c_str();
-}
-
-void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) {
-    fprintf(stream, "\n");
-    fprintf(stream, "###########\n");
-    fprintf(stream, "# Timings #\n");
-    fprintf(stream, "###########\n");
-    fprintf(stream, "\n");
-
-    fprintf(stream, "mst_eval: %.2f  # ms / token during generation\n",
-            1.0e-3 * ctx->t_eval_us / ctx->n_eval);
-    fprintf(stream, "mst_p_eval: %.2f  # ms / token during prompt processing\n",
-            1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval);
-    fprintf(stream, "mst_sample: %.2f  # ms / token during sampling\n",
-            1.0e-3 * ctx->t_sample_us / ctx->n_sample);
-    fprintf(stream, "n_eval: %d  # number of tokens generated (excluding the first one)\n", ctx->n_eval);
-    fprintf(stream, "n_p_eval: %d  # number of tokens processed in batches at the beginning\n", ctx->n_p_eval);
-    fprintf(stream, "n_sample: %d  # number of sampled tokens\n", ctx->n_sample);
-    fprintf(stream, "t_eval_us: %" PRId64 "  # total microseconds spent generating tokens\n", ctx->t_eval_us);
-    fprintf(stream, "t_load_us: %" PRId64 "  # total microseconds spent loading the model\n", ctx->t_load_us);
-    fprintf(stream, "t_p_eval_us: %" PRId64 "  # total microseconds spent prompt processing\n", ctx->t_p_eval_us);
-    fprintf(stream, "t_sample_us: %" PRId64 "  # total microseconds spent sampling\n", ctx->t_sample_us);
-    fprintf(stream, "ts_eval: %.2f  # tokens / second during generation\n",
-            1.0e6 * ctx->n_eval / ctx->t_eval_us);
-    fprintf(stream, "ts_p_eval: %.2f  # tokens / second during prompt processing\n",
-            1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us);
-    fprintf(stream, "ts_sample: %.2f  # tokens / second during sampling\n",
-            1.0e6 * ctx->n_sample / ctx->t_sample_us);
-}
-
-// For internal test use
-const std::vector> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-) {
-    return ctx->model.tensors_by_name;
-}
-
-void llama_log_set(ggml_log_callback log_callback, void * user_data) {
-    g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
-    g_state.log_callback_user_data = user_data;
-#ifdef GGML_USE_METAL
-    ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#elif defined(GGML_USE_CUDA)
-    ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
-#endif
-}
-
-static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
-    va_list args_copy;
-    va_copy(args_copy, args);
-    char buffer[128];
-    int len = vsnprintf(buffer, 128, format, args);
-    if (len < 128) {
-        g_state.log_callback(level, buffer, g_state.log_callback_user_data);
-    } else {
-        char* buffer2 = new char[len+1];
-        vsnprintf(buffer2, len+1, format, args_copy);
-        buffer2[len] = 0;
-        g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
-        delete[] buffer2;
-    }
-    va_end(args_copy);
-}
-
-static void llama_log_internal(ggml_log_level level, const char * format, ...) {
-    va_list args;
-    va_start(args, format);
-    llama_log_internal_v(level, format, args);
-    va_end(args);
-}
-
-static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
-    (void) level;
-    (void) user_data;
-    fputs(text, stderr);
-    fflush(stderr);
-}
+}
\ No newline at end of file
diff --git a/llama/llama.go b/llama/llama.go
index dbe60b6a..48469121 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -260,6 +260,7 @@ func (m *Model) TokenToPiece(token int) string {
 		C.int32_t(token),
 		(*C.char)(unsafe.Pointer(&buf[0])),
 		C.int32_t(12),
+		C.int32_t(0),
 		C.bool(true),
 	)
 	return strings.TrimRight(string(buf), "\x00")
diff --git a/llama/llama.h b/llama/llama.h
index 3e0355b2..469bf75e 100644
--- a/llama/llama.h
+++ b/llama/llama.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -59,17 +59,15 @@
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
-#define LLAMA_MAX_RNG_STATE (64*1024)
-
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 6
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 1
+#define LLAMA_STATE_SEQ_VERSION 2
 
 #ifdef __cplusplus
 extern "C" {
@@ -93,6 +91,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_SPM  = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback
         LLAMA_VOCAB_TYPE_BPE  = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM  = 3, // BERT tokenizer based on WordPiece
+        LLAMA_VOCAB_TYPE_UGM  = 4, // T5 tokenizer based on Unigram
     };
 
     // pre-tokenization types
@@ -112,6 +111,14 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
         LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
+        LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+        LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+        LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
+        LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+        LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
+        LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
     };
 
     // note: these values should be synchronized with ggml_rope
@@ -153,7 +160,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_F16           = 1,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_0          = 2,  // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q4_1          = 3,  // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
+        // LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
         // LLAMA_FTYPE_MOSTLY_Q4_2       = 5,  // support has been removed
         // LLAMA_FTYPE_MOSTLY_Q4_3       = 6,  // support has been removed
         LLAMA_FTYPE_MOSTLY_Q8_0          = 7,  // except 1d tensors
@@ -182,6 +189,9 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_XS        = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M         = 31, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_BF16          = 32, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_4_4      = 33, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_4_8      = 34, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0_8_8      = 35, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
@@ -199,6 +209,13 @@ extern "C" {
         LLAMA_POOLING_TYPE_NONE = 0,
         LLAMA_POOLING_TYPE_MEAN = 1,
         LLAMA_POOLING_TYPE_CLS  = 2,
+        LLAMA_POOLING_TYPE_LAST = 3,
+    };
+
+    enum llama_attention_type {
+        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
+        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
+        LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
     };
 
     enum llama_split_mode {
@@ -318,7 +335,7 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-                                                        // (ignored if no pooling layer)
+        enum llama_attention_type    attention_type;    // attention type to use for embeddings
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float    rope_freq_base;   // RoPE base frequency, 0 = from model
@@ -421,6 +438,9 @@ extern "C" {
         const char * content;
     } llama_chat_message;
 
+    // lora adapter
+    struct llama_lora_adapter;
+
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
@@ -507,24 +527,45 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
             const char * fname_out,
             const llama_model_quantize_params * params);
 
-    // Apply a LoRA adapter to a loaded model
-    // path_base_model is the path to a higher quality model to use as a base for
-    // the layers modified by the adapter. Can be NULL to use the current loaded model.
-    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
-    // will be applied on top of the previous one
-    // Returns 0 on success
-    LLAMA_API int32_t llama_model_apply_lora_from_file(
-            const struct llama_model * model,
-                          const char * path_lora,
-                               float   scale,
-                          const char * path_base_model,
-                             int32_t   n_threads);
+    // Load a LoRA adapter from file
+    // The loaded adapter will be associated to the given model, and will be free when the model is deleted
+    LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+            struct llama_model * model,
+            const char * path_lora);
+
+    // Add a loaded LoRA adapter to given context
+    // This will not modify model's weight
+    LLAMA_API int32_t llama_lora_adapter_set(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter,
+            float scale);
+
+    // Remove a specific LoRA adapter from given context
+    // Return -1 if the adapter is not present in the context
+    LLAMA_API int32_t llama_lora_adapter_remove(
+            struct llama_context * ctx,
+            struct llama_lora_adapter * adapter);
+
+    // Remove all LoRA adapters from given context
+    LLAMA_API void llama_lora_adapter_clear(
+            struct llama_context * ctx);
+
+    // Manually free a LoRA adapter
+    // Note: loaded adapters will be free when the associated model is deleted
+    LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
@@ -674,10 +715,11 @@ extern "C" {
     // State / sessions
     //
 
-    // Returns the maximum size in bytes of the state (rng, logits, embedding
-    // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
-    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+    // Returns the *actual* size in bytes of the state
+    // (rng, logits, embedding and kv_cache)
+    // Only use when saving the state, not when restoring it, otherwise the size may be too small.
+    LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
         "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
@@ -685,7 +727,8 @@ extern "C" {
     // Returns the number of bytes copied
     LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
-                         uint8_t * dst);
+                         uint8_t * dst,
+                          size_t   size);
     LLAMA_API DEPRECATED(size_t llama_copy_state_data(
             struct llama_context * ctx,
                          uint8_t * dst),
@@ -695,7 +738,8 @@ extern "C" {
     // Returns the number of bytes read
     LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
-                   const uint8_t * src);
+                   const uint8_t * src,
+                          size_t   size);
     LLAMA_API DEPRECATED(size_t llama_set_state_data(
             struct llama_context * ctx,
                    const uint8_t * src),
@@ -737,6 +781,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
                          uint8_t * dst,
+                          size_t   size,
                     llama_seq_id   seq_id);
 
     // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -746,6 +791,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_set_data(
             struct llama_context * ctx,
                    const uint8_t * src,
+                          size_t   size,
                     llama_seq_id   dest_seq_id);
 
     LLAMA_API size_t llama_state_seq_save_file(
@@ -792,6 +838,14 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
+    // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
+    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    //   0 - success
+    // < 0 - error
+    LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+              struct llama_batch   batch);
+
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@@ -811,6 +865,10 @@ extern "C" {
     // Get the number of threads used for prompt and batch processing (multiple token).
     LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);
 
+    // Set whether the model is in embeddings mode or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
@@ -878,12 +936,13 @@ extern "C" {
     LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
     LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
+    LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
 
     // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t         llama_add_bos_token(const struct llama_model * model);
+    LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
 
     // Returns -1 if unknown, 1 for true or 0 for false.
-    LLAMA_API int32_t         llama_add_eos_token(const struct llama_model * model);
+    LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
 
     // Codellama infill tokens
     LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
@@ -899,6 +958,7 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(
@@ -913,15 +973,35 @@ extern "C" {
     // Token Id -> Piece.
     // Uses the vocabulary in the provided context.
     // Does not write null terminator to the buffer.
-    // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
+    // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
     // @param special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_token_to_piece(
               const struct llama_model * model,
                            llama_token   token,
                                   char * buf,
                                int32_t   length,
+                               int32_t   lstrip,
                                   bool   special);
 
+    /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
+    /// @param text The char pointer must be large enough to hold the resulting text.
+    /// @return Returns the number of chars/bytes on success, no more than text_len_max.
+    /// @return Returns a negative number on failure - the number of chars/bytes that would have been returned.
+    /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
+    /// @param unparse_special If true, special tokens are rendered in the output.
+    LLAMA_API int32_t llama_detokenize(
+        const struct llama_model * model,
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special);
+
+    //
+    // Chat templates
+    //
+
     /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
     /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -945,6 +1025,12 @@ extern "C" {
     // Grammar
     //
 
+    /// Initialize a llama_grammar.
+    ///
+    /// @param rules The rule elements of the grammar to initialize.
+    /// @param n_rules The number of rules.
+    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
+    /// @return The initialized llama_grammar or nullptr if initialization failed.
     LLAMA_API struct llama_grammar * llama_grammar_init(
             const llama_grammar_element ** rules,
                                  size_t    n_rules,
@@ -954,6 +1040,23 @@ extern "C" {
 
     LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
 
+    /// @details Apply constraints from grammar
+    LLAMA_API void llama_grammar_sample(
+            const struct llama_grammar * grammar,
+            const struct llama_context * ctx,
+                llama_token_data_array * candidates);
+    LLAMA_API DEPRECATED(void llama_sample_grammar(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+      const struct llama_grammar * grammar),
+        "use llama_grammar_sample instead");
+
+    /// @details Accepts the sampled token into the grammar
+    LLAMA_API void llama_grammar_accept_token(
+            struct llama_grammar * grammar,
+            struct llama_context * ctx,
+                     llama_token   token);
+
     //
     // Sampling functions
     //
@@ -1035,12 +1138,6 @@ extern "C" {
           llama_token_data_array * candidates,
                            float   temp);
 
-    /// @details Apply constraints from grammar
-    LLAMA_API void llama_sample_grammar(
-            struct llama_context * ctx,
-          llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
-
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1078,12 +1175,6 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Accepts the sampled token into the grammar
-    LLAMA_API void llama_grammar_accept_token(
-            struct llama_context * ctx,
-            struct llama_grammar * grammar,
-                     llama_token   token);
-
     //
     // Model split
     //
@@ -1113,6 +1204,20 @@ extern "C" {
 
     LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
 
+    // Apply a LoRA adapter to a loaded model
+    // path_base_model is the path to a higher quality model to use as a base for
+    // the layers modified by the adapter. Can be NULL to use the current loaded model.
+    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+    // will be applied on top of the previous one
+    // Returns 0 on success
+    LLAMA_API int32_t llama_model_apply_lora_from_file(
+            const struct llama_model * model,
+                            const char * path_lora,
+                                float   scale,
+                            const char * path_base_model,
+                                int32_t   n_threads);
+
+
 #ifdef __cplusplus
 }
 #endif
@@ -1126,38 +1231,45 @@ extern "C" {
 
 struct ggml_tensor;
 
+const std::vector> & llama_internal_get_tensor_map(
+    struct llama_context * ctx
+);
+
 struct llama_partial_utf8 {
     uint32_t value;    // bit value so far (unshifted)
     int      n_remain; // num bytes remaining; -1 indicates invalid sequence
 };
 
-struct llama_grammar {
-    const std::vector>   rules;
-    std::vector> stacks;
-
-    // buffer for partially generated UTF-8 sequence from accepted tokens
-    llama_partial_utf8                                      partial_utf8;
-};
-
 struct llama_grammar_candidate {
     size_t               index;
     const uint32_t     * code_points;
     llama_partial_utf8   partial_utf8;
 };
 
-const std::vector> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
+using llama_grammar_rule  = std::vector<      llama_grammar_element>;
+using llama_grammar_stack = std::vector;
+
+using llama_grammar_rules      = std::vector;
+using llama_grammar_stacks     = std::vector;
+using llama_grammar_candidates = std::vector;
+
+const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
+      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
 
 void llama_grammar_accept(
-        const std::vector>         & rules,
-        const std::vector> & stacks,
-        const uint32_t                                                  chr,
-        std::vector>       & new_stacks);
+        const llama_grammar_rules  & rules,
+        const llama_grammar_stacks & stacks,
+        const uint32_t chr,
+              llama_grammar_stacks & new_stacks);
+
+std::vector llama_grammar_reject_candidates_for_stack(
+        const llama_grammar_rules & rules,
+        const llama_grammar_stack & stack,
+        const llama_grammar_candidates & candidates);
 
 std::pair, llama_partial_utf8> decode_utf8(
         const std::string & src,
-        llama_partial_utf8   partial_start);
+        llama_partial_utf8 partial_start);
 
 // Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
 // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
diff --git a/llama/llama_darwin.c b/llama/llama_darwin.c
index 3d41a75d..7d2d98c6 100644
--- a/llama/llama_darwin.c
+++ b/llama/llama_darwin.c
@@ -1,2 +1,236 @@
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+/**
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
 const char *ggml_metallib_start;
 const char *ggml_metallib_end;
diff --git a/llama/llava.cpp b/llama/llava.cpp
index fc55207b..d94196ec 100644
--- a/llama/llava.cpp
+++ b/llama/llava.cpp
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
diff --git a/llama/llava.h b/llama/llava.h
index 0d120283..61dde037 100644
--- a/llama/llava.h
+++ b/llama/llava.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
diff --git a/llama/log.h b/llama/log.h
index 03855df6..67e92545 100644
--- a/llama/log.h
+++ b/llama/log.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -656,7 +656,7 @@ inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
     buf << "[ ";
 
     bool first = true;
-    for (const auto &token : tokens)
+    for (const auto & token : tokens)
     {
         if (!first) {
             buf << ", ";
diff --git a/llama/patches/01-cuda.diff b/llama/patches/01-cuda.diff
index 66dd9d76..0ae13db1 100644
--- a/llama/patches/01-cuda.diff
+++ b/llama/patches/01-cuda.diff
@@ -1,7 +1,7 @@
-diff --git a/llama/ggml-backend.c b/llama/ggml-backend.c
+diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c
 index 9e35ce98..179be840 100644
---- a/llama/ggml-backend.c
-+++ b/llama/ggml-backend.c
+--- a/ggml/src/ggml-backend.c
++++ b/ggml/src/ggml-backend.c
 @@ -87,7 +87,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
      if (buffer->iface.free_buffer != NULL) {
          buffer->iface.free_buffer(buffer);
@@ -15,10 +15,10 @@ index 9e35ce98..179be840 100644
  }
  
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
-diff --git a/llama/ggml-cuda.cu b/llama/ggml-cuda.cu
+diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
 index 04b6e528..43b12bdf 100644
---- a/llama/ggml-cuda.cu
-+++ b/llama/ggml-cuda.cu
+--- a/ggml/src/ggml-cuda.cu
++++ b/ggml/src/ggml-cuda.cu
 @@ -392,6 +392,10 @@ GGML_CALL static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer)
  GGML_CALL static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
@@ -39,10 +39,10 @@ index 04b6e528..43b12bdf 100644
  GGML_CALL int ggml_backend_cuda_reg_devices() {
      int device_count = ggml_backend_cuda_get_device_count();
      //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
-diff --git a/llama/ggml-cuda.h b/llama/ggml-cuda.h
+diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h
 index 5eb4af40..50b91009 100644
---- a/llama/ggml-cuda.h
-+++ b/llama/ggml-cuda.h
+--- a/ggml/include/ggml-cuda.h
++++ b/ggml/include/ggml-cuda.h
 @@ -31,6 +31,8 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_typ
  // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
  GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
diff --git a/llama/patches/02-llamacpp.diff b/llama/patches/02-llamacpp.diff
index 7b5f9b70..0d40fc3c 100644
--- a/llama/patches/02-llamacpp.diff
+++ b/llama/patches/02-llamacpp.diff
@@ -1,11 +1,11 @@
-diff --git a/llama/llama.cpp b/llama/llama.cpp
-index 8b675ea9..bcc6ae75 100644
---- a/llama/llama.cpp
-+++ b/llama/llama.cpp
-@@ -4645,16 +4645,7 @@ static void llm_load_vocab(
- 
-         // for now, only BPE models have pre-tokenizers
+diff --git a/src/llama.cpp b/src/llama.cpp
+index a207451f..2ddf431d 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -5347,16 +5347,7 @@ static void llm_load_vocab(
          if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
+             vocab.tokenizer_add_space_prefix = false;
+             vocab.tokenizer_clean_spaces = true;
 -            if (tokenizer_pre.empty()) {
 -                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
@@ -20,22 +20,13 @@ index 8b675ea9..bcc6ae75 100644
                  vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
                      tokenizer_pre == "llama3"   ||
-@@ -4706,7 +4697,8 @@ static void llm_load_vocab(
-                 tokenizer_pre == "smaug-bpe") {
-                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+@@ -5443,7 +5434,8 @@ static void llm_load_vocab(
+                 tokenizer_pre == "codeshell") {
+                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
              } else {
 -                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
 +                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              }
-         } else {
+         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
              vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-@@ -7009,7 +7001,7 @@ static struct ggml_tensor * llm_build_kqv(
-         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
-         cb(kq, "kq", il);
- 
--        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
-+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
-             // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-             // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-             ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
diff --git a/llama/patches/03-metal.diff b/llama/patches/03-metal.diff
index f3edde3e..e63732e7 100644
--- a/llama/patches/03-metal.diff
+++ b/llama/patches/03-metal.diff
@@ -1,7 +1,7 @@
-diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
 index 0207b787..b5e9884b 100644
---- a/llama/ggml-metal-darwin_arm64.m
-+++ b/llama/ggml-metal-darwin_arm64.m
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
 @@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute(
                          // to the matrix-vector kernel
                          int ne11_mm_min = 1;
diff --git a/llama/patches/04-ggml-metal.diff b/llama/patches/04-ggml-metal.diff
index 7ee48cf5..b3b7f14c 100644
--- a/llama/patches/04-ggml-metal.diff
+++ b/llama/patches/04-ggml-metal.diff
@@ -1,7 +1,7 @@
-diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m
+diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
 index b56c3604..400d43f4 100644
---- a/llama/ggml-metal-darwin_arm64.m
-+++ b/llama/ggml-metal-darwin_arm64.m
+--- a/ggml/src/ggml-metal.m
++++ b/ggml/src/ggml-metal.m
 @@ -377,8 +377,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
  #if GGML_METAL_EMBED_LIBRARY
              GGML_METAL_LOG_INFO("%s: using embedded metal library\n", __func__);
diff --git a/llama/patches/06-embeddings.diff b/llama/patches/06-embeddings.diff
new file mode 100644
index 00000000..a84e3b06
--- /dev/null
+++ b/llama/patches/06-embeddings.diff
@@ -0,0 +1,45 @@
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 1fe2b9f7..a43312a7 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -13689,7 +13689,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
+     const auto n_embd  = hparams.n_embd;
+ 
+     // TODO: use a per-batch flag for logits presence instead
+-    const bool has_logits = !cparams.embeddings;
++    const bool has_logits =  cparams.causal_attn;
+     const bool has_embd   =  lctx.is_encoding || (cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
+ 
+     const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+@@ -13959,17 +13959,25 @@ static int llama_decode_internal(
+             // no output
+             res  = nullptr;
+             embd = nullptr;
+-        } else if (cparams.embeddings) {
+-            res = nullptr; // do not extract logits for embedding case
+-            embd = gf->nodes[gf->n_nodes - 1];
+-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
+-                embd = gf->nodes[gf->n_nodes - 2];
++        }
++
++        if (cparams.embeddings) {
++            for (int i = gf->n_nodes - 1; i >= 0; --i) {
++                embd = gf->nodes[i];
++                if (strcmp(embd->name, "result_embd_pooled") == 0) {
++                    break;
++                }
+             }
+             GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+-        } else {
++         } else {
+             embd = nullptr; // do not extract embeddings when not needed
+             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
+         }
++
++        if (!cparams.causal_attn) {
++            res = nullptr; // do not extract logits when not needed
++        }
++
+         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+ 
+         ggml_backend_sched_alloc_graph(lctx.sched, gf);
diff --git a/llama/patches/09-lora.diff b/llama/patches/09-lora.diff
new file mode 100644
index 00000000..10c66d1d
--- /dev/null
+++ b/llama/patches/09-lora.diff
@@ -0,0 +1,358 @@
+diff --git a/common/common.cpp b/common/common.cpp
+index dbb724fb..c26fe6ee 100644
+--- a/common/common.cpp
++++ b/common/common.cpp
+@@ -2087,14 +2087,27 @@ std::tuple llama_init_from_gpt_par
+     for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
+         const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
+         float lora_scale = std::get<1>(params.lora_adapter[i]);
++
++        // try to load as gguf
+         auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
+         if (adapter == nullptr) {
+-            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+-            llama_free(lctx);
+-            llama_free_model(model);
+-            return std::make_tuple(nullptr, nullptr);
++            fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
++
++            // if that fails, try loading as ggla for compatibility
++            int err = llama_model_apply_lora_from_file(model,
++                                                    lora_adapter.c_str(),
++                                                    lora_scale,
++                                                    nullptr,
++                                                    params.n_threads);
++            if (err != 0) {
++                fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
++                llama_free(lctx);
++                llama_free_model(model);
++                return std::make_tuple(nullptr, nullptr);
++            }
++        } else {
++            llama_lora_adapter_set(lctx, adapter, lora_scale);
+         }
+-        llama_lora_adapter_set(lctx, adapter, lora_scale);
+     }
+ 
+     if (params.ignore_eos) {
+diff --git a/include/llama.h b/include/llama.h
+index 93fd77ca..b0fb37a6 100644
+--- a/include/llama.h
++++ b/include/llama.h
+@@ -1160,6 +1160,20 @@ extern "C" {
+ 
+     LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
+ 
++    // Apply a LoRA adapter to a loaded model
++    // path_base_model is the path to a higher quality model to use as a base for
++    // the layers modified by the adapter. Can be NULL to use the current loaded model.
++    // The model needs to be reloaded before applying a new adapter, otherwise the adapter
++    // will be applied on top of the previous one
++    // Returns 0 on success
++    LLAMA_API int32_t llama_model_apply_lora_from_file(
++            const struct llama_model * model,
++                            const char * path_lora,
++                                float   scale,
++                            const char * path_base_model,
++                                int32_t   n_threads);
++
++
+ #ifdef __cplusplus
+ }
+ #endif
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 80a0dd0f..9d7b0e17 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
+     fputs(text, stderr);
+     fflush(stderr);
+ }
++
++static int llama_apply_lora_from_file_internal(
++    const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
++) {
++    LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
++
++    const int64_t t_start_lora_us = ggml_time_us();
++
++    llama_file fin(path_lora, "rb");
++
++    // verify magic and version
++    {
++        uint32_t magic = fin.read_u32();
++        if (magic != LLAMA_FILE_MAGIC_GGLA) {
++            LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
++            return 1;
++        }
++
++        uint32_t format_version = fin.read_u32();
++        if (format_version != 1) {
++            LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
++            return 1;
++        }
++    }
++
++    int32_t lora_r = fin.read_u32();
++    int32_t lora_alpha = fin.read_u32();
++    float scaling = scale * (float)lora_alpha / (float)lora_r;
++
++    LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
++
++    // load base model
++    std::unique_ptr ml;
++    if (path_base_model) {
++        LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
++        ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
++        ml->init_mappings(/*prefetch*/ false); // no prefetching
++    }
++
++    struct tensor_meta {
++        std::string name;
++        ggml_type type;
++        int32_t ne[2];
++        size_t offset;
++    };
++    std::map tensor_meta_map;
++
++    // load all tensor meta
++    while (true) {
++        if (fin.tell() == fin.size) {
++            // eof
++            break;
++        }
++
++        int32_t n_dims;
++        int32_t name_len;
++        int32_t ftype;
++
++        fin.read_raw(&n_dims, sizeof(n_dims));
++        fin.read_raw(&name_len, sizeof(name_len));
++        fin.read_raw(&ftype, sizeof(ftype));
++
++        if (n_dims != 1 && n_dims != 2) {
++            LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
++            return 1;
++        }
++
++        int32_t ne[2] = { 1, 1 };
++        for (int i = 0; i < n_dims; ++i) {
++            fin.read_raw(&ne[i], sizeof(ne[i]));
++        }
++
++        std::string name;
++        {
++            GGML_ASSERT(name_len < GGML_MAX_NAME);
++            char buf[GGML_MAX_NAME];
++            fin.read_raw(buf, name_len);
++            name = std::string(buf, name_len);
++        }
++
++        // check for lora suffix
++        std::string lora_suffix;
++        if (name.length() > 6) {
++            lora_suffix = name.substr(name.length() - 6);
++        }
++        if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
++            LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
++            return 1;
++        }
++
++        // tensor type
++        ggml_type wtype;
++        switch (ftype) {
++            case 0: wtype = GGML_TYPE_F32;  break;
++            case 1: wtype = GGML_TYPE_F16;  break;
++            default:
++                    {
++                        LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
++                                __func__, ftype);
++                        return 1;
++                    }
++        }
++
++        // data offset
++        size_t offset = fin.tell();
++        offset = (offset + 31) & -32;
++
++        // skip tensor data
++        fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
++
++        tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
++    }
++
++    bool warned = false;
++    int n_tensors = 0;
++
++    // apply
++    ggml_backend_t backend_cpu = ggml_backend_cpu_init();
++    if (backend_cpu == nullptr) {
++        LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
++        return 1;
++    }
++    ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
++
++    std::vector> read_buf;
++    for (const auto & it : model.tensors_by_name) {
++        const std::string & base_name = it.first;
++        ggml_tensor * model_t = it.second;
++
++        if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
++            tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
++            continue;
++        }
++
++        tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
++        tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
++
++        ggml_init_params lora_init_params = {
++            /* .mem_size   */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
++            /* .mem_buffer */ nullptr,
++            /* .no_alloc   */ true,
++        };
++        ggml_context * lora_ctx = ggml_init(lora_init_params);
++        if (lora_ctx == nullptr) {
++            LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
++            ggml_backend_free(backend_cpu);
++            return 1;
++        }
++
++        // create tensors
++        ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
++        ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
++        ggml_set_name(loraA, metaA.name.c_str());
++        ggml_set_name(loraB, metaB.name.c_str());
++
++        ggml_tensor * base_t;
++        if (ml) {
++            if (!ml->get_tensor_meta(base_name.c_str())) {
++                LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
++                return 1;
++            }
++            base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
++        } else {
++            base_t = ggml_dup_tensor(lora_ctx, model_t);
++        }
++        ggml_set_name(base_t, base_name.c_str());
++
++        // allocate in backend buffer
++        ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
++        if (lora_buf == nullptr) {
++            LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
++            return 1;
++        }
++
++        // load tensor data
++        auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
++            read_buf.resize(ggml_nbytes(tensor));
++            fin.seek(tensor_meta.offset, SEEK_SET);
++            fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
++            ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
++        };
++        load_tensor(metaA, loraA);
++        load_tensor(metaB, loraB);
++
++        // load base model tensor data
++        if (ml) {
++            ml->load_data_for(base_t);
++        } else {
++            ggml_backend_tensor_copy(model_t, base_t);
++        }
++
++        if (ggml_is_quantized(base_t->type) && !warned) {
++            LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
++                            "use a f16 or f32 base model with --lora-base\n", __func__);
++            warned = true;
++        }
++
++        if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
++            LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
++                            " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
++            ggml_free(lora_ctx);
++            ggml_backend_buffer_free(lora_buf);
++            ggml_backend_free(backend_cpu);
++            return 1;
++        }
++
++        auto build_lora_graph = [&]() {
++            // w = w + BA*s
++            ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
++            ggml_set_name(BA, "BA");
++
++            if (scaling != 1.0f) {
++                BA = ggml_scale(lora_ctx, BA, scaling);
++                ggml_set_name(BA, "BA_scaled");
++            }
++
++            ggml_tensor * r;
++            r = ggml_add_inplace(lora_ctx, base_t, BA);
++            ggml_set_name(r, "r_add");
++
++            if (base_t->type != model_t->type) {
++                // convert the result to the model type
++                r = ggml_cast(lora_ctx, r, model_t->type);
++                ggml_set_name(r, "r_cast");
++            }
++
++            return r;
++        };
++
++        ggml_cgraph * gf = ggml_new_graph(lora_ctx);
++        ggml_tensor * r = build_lora_graph();
++        ggml_build_forward_expand(gf, r);
++
++        ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
++        if (graph_buf == nullptr) {
++            LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
++            ggml_free(lora_ctx);
++            ggml_backend_buffer_free(lora_buf);
++            ggml_backend_free(backend_cpu);
++            return 1;
++        }
++
++        ggml_backend_graph_compute(backend_cpu, gf);
++
++        ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
++
++#if 0
++        // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
++        //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
++
++        // sched compute
++        ggml_build_forward_expand(gf, build_graph());
++        ggml_backend_sched_init_measure(sched, gf);
++
++        // create the graph again, since the previous one was destroyed by the measure
++        ggml_graph_clear(gf);
++        ggml_build_forward_expand(gf, build_graph());
++        ggml_backend_sched_graph_compute(sched, gf);
++        ggml_backend_sched_free(sched);
++#endif
++
++        ggml_backend_buffer_free(lora_buf);
++        ggml_backend_buffer_free(graph_buf);
++        ggml_free(lora_ctx);
++
++        n_tensors++;
++        if (n_tensors % 4 == 0) {
++            LLAMA_LOG_INFO(".");
++        }
++    }
++
++    ggml_backend_free(backend_cpu);
++
++    const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
++    LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
++
++    return 0;
++}
++
++int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
++    try {
++        return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
++    } catch (const std::exception & err) {
++        LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
++        return 1;
++    }
++}
+\ No newline at end of file
diff --git a/llama/sampling.cpp b/llama/sampling.cpp
index 07f6219d..1985ac2f 100644
--- a/llama/sampling.cpp
+++ b/llama/sampling.cpp
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -54,9 +54,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
         std::vector grammar_rules(result->parsed_grammar.c_rules());
 
-        result->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        result->grammar = grammar;
     }
 
     result->prev.resize(params.n_prev);
@@ -85,9 +89,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     if (!ctx->parsed_grammar.rules.empty()) {
         std::vector grammar_rules(ctx->parsed_grammar.c_rules());
 
-        ctx->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        ctx->grammar = grammar;
     }
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
@@ -300,8 +308,6 @@ static llama_token llama_sampling_sample_impl(
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -342,12 +348,15 @@ static llama_token llama_sampling_sample_impl(
     }
 
     if (ctx_sampling->grammar != NULL && !is_resampling) {
+        // Get a pointer to the logits
+        float * logits = llama_get_logits_ith(ctx_main, idx);
+
         // Create an array with a single token data element for the sampled id
         llama_token_data single_token_data = {id, logits[id], 0.0f};
         llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
 
         // Apply grammar constraints to the single token
-        llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
+        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
 
         // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
         bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -395,7 +404,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
-        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+        *original_logits = {logits, logits + n_vocab};
     }
 
     // apply params.logit_bias map
@@ -408,10 +417,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
 
-    cur.clear();
+    cur.resize(n_vocab);
 
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
     }
 
     llama_token_data_array cur_p = { cur.data(), cur.size(), false };
@@ -438,7 +447,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // apply grammar checks before sampling logic
     if (apply_grammar && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
     }
 
     return cur_p;
@@ -472,6 +481,6 @@ void llama_sampling_accept(
     ctx_sampling->prev.push_back(id);
 
     if (ctx_sampling->grammar != NULL && apply_grammar) {
-        llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
     }
 }
diff --git a/llama/sampling.h b/llama/sampling.h
index cc6524b3..30b4134f 100644
--- a/llama/sampling.h
+++ b/llama/sampling.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
diff --git a/llama/sgemm.cpp b/llama/sgemm.cpp
index 40ba9d7e..6626ceb2 100644
--- a/llama/sgemm.cpp
+++ b/llama/sgemm.cpp
@@ -43,8 +43,10 @@
 // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
 //     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
 
+#if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Wpedantic"
 #pragma GCC diagnostic ignored "-Wignored-attributes"
+#endif
 
 #include "sgemm.h"
 #include "ggml-impl.h"
@@ -247,9 +249,8 @@ class tinyBLAS {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -456,9 +457,8 @@ class tinyBLAS_Q0_ARM {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -594,9 +594,8 @@ class tinyBLAS_Q0_AVX {
         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
     }
 
-    void matmul(int64_t m, int64_t n, int task) {
-        if (task == GGML_TASK_TYPE_COMPUTE)
-            mnpack(0, m, 0, n);
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
     }
 
   private:
@@ -827,7 +826,7 @@ class tinyBLAS_Q0_AVX {
  * For example, for single-threaded single-precision GEMM you can say
  *
  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
- *                     0, 1, GGML_TASK_TYPE_COMPUTE,
+ *                     0, 1,
  *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
  *
  * @param m is rows in `A` and `C`
@@ -841,14 +840,13 @@ class tinyBLAS_Q0_AVX {
  * @param ldc is row stride of `C`
  * @param ith is thread id (must be less than `nth`)
  * @param nth is number of threads (must be greater than zero)
- * @param task is GGML task type
  * @param Atype is GGML data type of `A`
  * @param Btype is GGML data type of `B`
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
 bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
+                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -875,7 +873,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__AVX__) || defined(__AVX2__)
         if (k % 8)
@@ -885,7 +883,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_NEON)
         if (n < 4)
@@ -897,7 +895,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -915,7 +913,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (k % 8)
@@ -927,7 +925,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
@@ -941,7 +939,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const ggml_fp16_t *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (k % 4)
@@ -953,7 +951,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const float *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -969,7 +967,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM tb{
@@ -977,7 +975,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -993,7 +991,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
         tinyBLAS_Q0_ARM tb{
@@ -1001,7 +999,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
             ith, nth};
-        tb.matmul(m, n, task);
+        tb.matmul(m, n);
         return true;
 #else
         return false;
@@ -1023,7 +1021,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldc;
     (void)ith;
     (void)nth;
-    (void)task;
     (void)Atype;
     (void)Btype;
     (void)Ctype;
diff --git a/llama/sgemm.h b/llama/sgemm.h
index f29747d0..caf6dd55 100644
--- a/llama/sgemm.h
+++ b/llama/sgemm.h
@@ -7,7 +7,7 @@ extern "C" {
 
 bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
                      const void *, int64_t, void *, int64_t, int, int,
-                     int, int, int, int);
+                     int, int, int);
 
 #ifdef __cplusplus
 }
diff --git a/llama/stb_image.h b/llama/stb_image.h
index a2f5386a..ed9badad 100644
--- a/llama/stb_image.h
+++ b/llama/stb_image.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
diff --git a/llama/sync.sh b/llama/sync.sh
index 3a3d6b49..1dfb4d89 100755
--- a/llama/sync.sh
+++ b/llama/sync.sh
@@ -1,9 +1,12 @@
 #!/bin/bash
 
+# Run in the llama directory
+
 set -e
 
 # Set the source directory
-src_dir=$1
+# TODO in the future: src_dir=$1
+src_dir=../llm/llama.cpp
 
 if [ -z "$src_dir" ]; then
   echo "Usage: $0 LLAMA_CPP_DIR"
@@ -11,41 +14,63 @@ if [ -z "$src_dir" ]; then
 fi
 
 # Set the destination directory
-dst_dir=.
+dst_dir=$(pwd)
+
+# TODO remove once we no longer use the submodule
+if [ -z "${OLLAMA_SKIP_PATCHING}" ]; then
+  (cd ../ && git submodule init && git submodule update --force ./llm/llama.cpp)
+
+  # apply patches
+  for patch in $dst_dir/patches/*.diff; do
+    echo "Applying $patch"
+    git -C $src_dir apply "$patch"
+  done
+else 
+  echo "Skipping patching"
+fi
 
 # llama.cpp
-cp $src_dir/unicode.cpp $dst_dir/unicode.cpp
-cp $src_dir/unicode.h $dst_dir/unicode.h
-cp $src_dir/unicode-data.cpp $dst_dir/unicode-data.cpp
-cp $src_dir/unicode-data.h $dst_dir/unicode-data.h
-cp $src_dir/llama.cpp $dst_dir/llama.cpp
-cp $src_dir/llama.h $dst_dir/llama.h
-cp $src_dir/sgemm.cpp $dst_dir/sgemm.cpp
-cp $src_dir/sgemm.h $dst_dir/sgemm.h
+cp $src_dir/src/unicode.cpp $dst_dir/unicode.cpp
+cp $src_dir/src/unicode.h $dst_dir/unicode.h
+cp $src_dir/src/unicode-data.cpp $dst_dir/unicode-data.cpp
+cp $src_dir/src/unicode-data.h $dst_dir/unicode-data.h
+cp $src_dir/src/llama.cpp $dst_dir/llama.cpp
+cp $src_dir/src/llama-impl.h $dst_dir/llama-impl.h
+cp $src_dir/src/llama-vocab.cpp $dst_dir/llama-vocab.cpp
+cp $src_dir/src/llama-vocab.h $dst_dir/llama-vocab.h
+cp $src_dir/src/llama-grammar.cpp $dst_dir/llama-grammar.cpp
+cp $src_dir/src/llama-grammar.h $dst_dir/llama-grammar.h
+cp $src_dir/src/llama-sampling.cpp $dst_dir/llama-sampling.cpp
+cp $src_dir/src/llama-sampling.h $dst_dir/llama-sampling.h
+cp $src_dir/include/llama.h $dst_dir/llama.h
+cp $src_dir/ggml/src/llamafile/sgemm.cpp $dst_dir/sgemm.cpp
+cp $src_dir/ggml/src/llamafile/sgemm.h $dst_dir/sgemm.h
 
 # ggml
-cp $src_dir/ggml.c $dst_dir/ggml.c
-cp $src_dir/ggml.h $dst_dir/ggml.h
-cp $src_dir/ggml-quants.c $dst_dir/ggml-quants.c
-cp $src_dir/ggml-quants.h $dst_dir/ggml-quants.h
-cp $src_dir/ggml-metal.metal $dst_dir/ggml-metal.metal
-cp $src_dir/ggml-metal.h $dst_dir/ggml-metal.h
-cp $src_dir/ggml-metal.m $dst_dir/ggml-metal-darwin_arm64.m
-cp $src_dir/ggml-impl.h $dst_dir/ggml-impl.h
-cp $src_dir/ggml-cuda.h $dst_dir/ggml-cuda.h
-cp $src_dir/ggml-cuda.cu $dst_dir/ggml-cuda.cu
-cp $src_dir/ggml-common.h $dst_dir/ggml-common.h
-cp $src_dir/ggml-backend.h $dst_dir/ggml-backend.h
-cp $src_dir/ggml-backend.c $dst_dir/ggml-backend.c
-cp $src_dir/ggml-backend-impl.h $dst_dir/ggml-backend-impl.h
-cp $src_dir/ggml-alloc.h $dst_dir/ggml-alloc.h
-cp $src_dir/ggml-alloc.c $dst_dir/ggml-alloc.c
+cp $src_dir/ggml/src/ggml.c $dst_dir/ggml.c
+cp $src_dir/ggml/include/ggml.h $dst_dir/ggml.h
+cp $src_dir/ggml/src/ggml-quants.c $dst_dir/ggml-quants.c
+cp $src_dir/ggml/src/ggml-quants.h $dst_dir/ggml-quants.h
+cp $src_dir/ggml/src/ggml-metal.metal $dst_dir/ggml-metal.metal
+cp $src_dir/ggml/include/ggml-metal.h $dst_dir/ggml-metal.h
+cp $src_dir/ggml/src/ggml-metal.m $dst_dir/ggml-metal-darwin_arm64.m
+cp $src_dir/ggml/src/ggml-impl.h $dst_dir/ggml-impl.h
+cp $src_dir/ggml/include/ggml-cuda.h $dst_dir/ggml-cuda.h
+cp $src_dir/ggml/src/ggml-cuda.cu $dst_dir/ggml-cuda.cu
+cp $src_dir/ggml/src/ggml-common.h $dst_dir/ggml-common.h
+cp $src_dir/ggml/include/ggml-backend.h $dst_dir/ggml-backend.h
+cp $src_dir/ggml/src/ggml-backend.c $dst_dir/ggml-backend.c
+cp $src_dir/ggml/src/ggml-backend-impl.h $dst_dir/ggml-backend-impl.h
+cp $src_dir/ggml/include/ggml-alloc.h $dst_dir/ggml-alloc.h
+cp $src_dir/ggml/src/ggml-alloc.c $dst_dir/ggml-alloc.c
+cp $src_dir/ggml/src/ggml-aarch64.h $dst_dir/ggml-aarch64.h
+cp $src_dir/ggml/src/ggml-aarch64.c $dst_dir/ggml-aarch64.c
 
 # ggml-cuda
 mkdir -p $dst_dir/ggml-cuda/template-instances
-cp $src_dir/ggml-cuda/*.cu $dst_dir/ggml-cuda/
-cp $src_dir/ggml-cuda/*.cuh $dst_dir/ggml-cuda/
-cp $src_dir/ggml-cuda/template-instances/*.cu $dst_dir/ggml-cuda/template-instances/
+cp $src_dir/ggml/src/ggml-cuda/*.cu $dst_dir/ggml-cuda/
+cp $src_dir/ggml/src/ggml-cuda/*.cuh $dst_dir/ggml-cuda/
+cp $src_dir/ggml/src/ggml-cuda/template-instances/*.cu $dst_dir/ggml-cuda/template-instances/
 
 # llava
 cp $src_dir/examples/llava/clip.cpp $dst_dir/clip.cpp
@@ -74,11 +99,6 @@ char const *LLAMA_COMPILER = "";
 char const *LLAMA_BUILD_TARGET = "";
 EOF
 
-# apply patches
-for patch in $dst_dir/patches/*.diff; do
-  git apply "$patch"
-done
-
 # add licenses
 sha1=$(git -C $src_dir rev-parse @)
 
diff --git a/llama/unicode-data.cpp b/llama/unicode-data.cpp
index 33efff9f..ae01e5c4 100644
--- a/llama/unicode-data.cpp
+++ b/llama/unicode-data.cpp
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -94,36 +94,36 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000370, 0x0004},
 {0x000375, 0x0040},
 {0x000376, 0x0004},
-{0x000378, 0x0080},
+{0x000378, 0x0001},
 {0x00037A, 0x0004},
 {0x00037E, 0x0020},
 {0x00037F, 0x0004},
-{0x000380, 0x0080},
+{0x000380, 0x0001},
 {0x000384, 0x0040},
 {0x000386, 0x0004},
 {0x000387, 0x0020},
 {0x000388, 0x0004},
-{0x00038B, 0x0080},
+{0x00038B, 0x0001},
 {0x00038C, 0x0004},
-{0x00038D, 0x0080},
+{0x00038D, 0x0001},
 {0x00038E, 0x0004},
-{0x0003A2, 0x0080},
+{0x0003A2, 0x0001},
 {0x0003A3, 0x0004},
 {0x0003F6, 0x0040},
 {0x0003F7, 0x0004},
 {0x000482, 0x0040},
 {0x000483, 0x0010},
 {0x00048A, 0x0004},
-{0x000530, 0x0080},
+{0x000530, 0x0001},
 {0x000531, 0x0004},
-{0x000557, 0x0080},
+{0x000557, 0x0001},
 {0x000559, 0x0004},
 {0x00055A, 0x0020},
 {0x000560, 0x0004},
 {0x000589, 0x0020},
-{0x00058B, 0x0080},
+{0x00058B, 0x0001},
 {0x00058D, 0x0040},
-{0x000590, 0x0080},
+{0x000590, 0x0001},
 {0x000591, 0x0010},
 {0x0005BE, 0x0020},
 {0x0005BF, 0x0010},
@@ -133,12 +133,13 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0005C4, 0x0010},
 {0x0005C6, 0x0020},
 {0x0005C7, 0x0010},
-{0x0005C8, 0x0080},
+{0x0005C8, 0x0001},
 {0x0005D0, 0x0004},
-{0x0005EB, 0x0080},
+{0x0005EB, 0x0001},
 {0x0005EF, 0x0004},
 {0x0005F3, 0x0020},
-{0x0005F5, 0x0080},
+{0x0005F5, 0x0001},
+{0x000600, 0x0080},
 {0x000606, 0x0040},
 {0x000609, 0x0020},
 {0x00060B, 0x0040},
@@ -171,16 +172,17 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0006FD, 0x0040},
 {0x0006FF, 0x0004},
 {0x000700, 0x0020},
-{0x00070E, 0x0080},
+{0x00070E, 0x0001},
+{0x00070F, 0x0080},
 {0x000710, 0x0004},
 {0x000711, 0x0010},
 {0x000712, 0x0004},
 {0x000730, 0x0010},
-{0x00074B, 0x0080},
+{0x00074B, 0x0001},
 {0x00074D, 0x0004},
 {0x0007A6, 0x0010},
 {0x0007B1, 0x0004},
-{0x0007B2, 0x0080},
+{0x0007B2, 0x0001},
 {0x0007C0, 0x0002},
 {0x0007CA, 0x0004},
 {0x0007EB, 0x0010},
@@ -188,7 +190,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0007F6, 0x0040},
 {0x0007F7, 0x0020},
 {0x0007FA, 0x0004},
-{0x0007FB, 0x0080},
+{0x0007FB, 0x0001},
 {0x0007FD, 0x0010},
 {0x0007FE, 0x0040},
 {0x000800, 0x0004},
@@ -199,20 +201,22 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000825, 0x0010},
 {0x000828, 0x0004},
 {0x000829, 0x0010},
-{0x00082E, 0x0080},
+{0x00082E, 0x0001},
 {0x000830, 0x0020},
-{0x00083F, 0x0080},
+{0x00083F, 0x0001},
 {0x000840, 0x0004},
 {0x000859, 0x0010},
-{0x00085C, 0x0080},
+{0x00085C, 0x0001},
 {0x00085E, 0x0020},
-{0x00085F, 0x0080},
+{0x00085F, 0x0001},
 {0x000860, 0x0004},
-{0x00086B, 0x0080},
+{0x00086B, 0x0001},
 {0x000870, 0x0004},
 {0x000888, 0x0040},
 {0x000889, 0x0004},
-{0x00088F, 0x0080},
+{0x00088F, 0x0001},
+{0x000890, 0x0080},
+{0x000892, 0x0001},
 {0x000898, 0x0010},
 {0x0008A0, 0x0004},
 {0x0008CA, 0x0010},
@@ -231,35 +235,35 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000970, 0x0020},
 {0x000971, 0x0004},
 {0x000981, 0x0010},
-{0x000984, 0x0080},
+{0x000984, 0x0001},
 {0x000985, 0x0004},
-{0x00098D, 0x0080},
+{0x00098D, 0x0001},
 {0x00098F, 0x0004},
-{0x000991, 0x0080},
+{0x000991, 0x0001},
 {0x000993, 0x0004},
-{0x0009A9, 0x0080},
+{0x0009A9, 0x0001},
 {0x0009AA, 0x0004},
-{0x0009B1, 0x0080},
+{0x0009B1, 0x0001},
 {0x0009B2, 0x0004},
-{0x0009B3, 0x0080},
+{0x0009B3, 0x0001},
 {0x0009B6, 0x0004},
-{0x0009BA, 0x0080},
+{0x0009BA, 0x0001},
 {0x0009BC, 0x0010},
 {0x0009BD, 0x0004},
 {0x0009BE, 0x0010},
-{0x0009C5, 0x0080},
+{0x0009C5, 0x0001},
 {0x0009C7, 0x0010},
-{0x0009C9, 0x0080},
+{0x0009C9, 0x0001},
 {0x0009CB, 0x0010},
 {0x0009CE, 0x0004},
-{0x0009CF, 0x0080},
+{0x0009CF, 0x0001},
 {0x0009D7, 0x0010},
-{0x0009D8, 0x0080},
+{0x0009D8, 0x0001},
 {0x0009DC, 0x0004},
-{0x0009DE, 0x0080},
+{0x0009DE, 0x0001},
 {0x0009DF, 0x0004},
 {0x0009E2, 0x0010},
-{0x0009E4, 0x0080},
+{0x0009E4, 0x0001},
 {0x0009E6, 0x0002},
 {0x0009F0, 0x0004},
 {0x0009F2, 0x0040},
@@ -268,173 +272,173 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0009FC, 0x0004},
 {0x0009FD, 0x0020},
 {0x0009FE, 0x0010},
-{0x0009FF, 0x0080},
+{0x0009FF, 0x0001},
 {0x000A01, 0x0010},
-{0x000A04, 0x0080},
+{0x000A04, 0x0001},
 {0x000A05, 0x0004},
-{0x000A0B, 0x0080},
+{0x000A0B, 0x0001},
 {0x000A0F, 0x0004},
-{0x000A11, 0x0080},
+{0x000A11, 0x0001},
 {0x000A13, 0x0004},
-{0x000A29, 0x0080},
+{0x000A29, 0x0001},
 {0x000A2A, 0x0004},
-{0x000A31, 0x0080},
+{0x000A31, 0x0001},
 {0x000A32, 0x0004},
-{0x000A34, 0x0080},
+{0x000A34, 0x0001},
 {0x000A35, 0x0004},
-{0x000A37, 0x0080},
+{0x000A37, 0x0001},
 {0x000A38, 0x0004},
-{0x000A3A, 0x0080},
+{0x000A3A, 0x0001},
 {0x000A3C, 0x0010},
-{0x000A3D, 0x0080},
+{0x000A3D, 0x0001},
 {0x000A3E, 0x0010},
-{0x000A43, 0x0080},
+{0x000A43, 0x0001},
 {0x000A47, 0x0010},
-{0x000A49, 0x0080},
+{0x000A49, 0x0001},
 {0x000A4B, 0x0010},
-{0x000A4E, 0x0080},
+{0x000A4E, 0x0001},
 {0x000A51, 0x0010},
-{0x000A52, 0x0080},
+{0x000A52, 0x0001},
 {0x000A59, 0x0004},
-{0x000A5D, 0x0080},
+{0x000A5D, 0x0001},
 {0x000A5E, 0x0004},
-{0x000A5F, 0x0080},
+{0x000A5F, 0x0001},
 {0x000A66, 0x0002},
 {0x000A70, 0x0010},
 {0x000A72, 0x0004},
 {0x000A75, 0x0010},
 {0x000A76, 0x0020},
-{0x000A77, 0x0080},
+{0x000A77, 0x0001},
 {0x000A81, 0x0010},
-{0x000A84, 0x0080},
+{0x000A84, 0x0001},
 {0x000A85, 0x0004},
-{0x000A8E, 0x0080},
+{0x000A8E, 0x0001},
 {0x000A8F, 0x0004},
-{0x000A92, 0x0080},
+{0x000A92, 0x0001},
 {0x000A93, 0x0004},
-{0x000AA9, 0x0080},
+{0x000AA9, 0x0001},
 {0x000AAA, 0x0004},
-{0x000AB1, 0x0080},
+{0x000AB1, 0x0001},
 {0x000AB2, 0x0004},
-{0x000AB4, 0x0080},
+{0x000AB4, 0x0001},
 {0x000AB5, 0x0004},
-{0x000ABA, 0x0080},
+{0x000ABA, 0x0001},
 {0x000ABC, 0x0010},
 {0x000ABD, 0x0004},
 {0x000ABE, 0x0010},
-{0x000AC6, 0x0080},
+{0x000AC6, 0x0001},
 {0x000AC7, 0x0010},
-{0x000ACA, 0x0080},
+{0x000ACA, 0x0001},
 {0x000ACB, 0x0010},
-{0x000ACE, 0x0080},
+{0x000ACE, 0x0001},
 {0x000AD0, 0x0004},
-{0x000AD1, 0x0080},
+{0x000AD1, 0x0001},
 {0x000AE0, 0x0004},
 {0x000AE2, 0x0010},
-{0x000AE4, 0x0080},
+{0x000AE4, 0x0001},
 {0x000AE6, 0x0002},
 {0x000AF0, 0x0020},
 {0x000AF1, 0x0040},
-{0x000AF2, 0x0080},
+{0x000AF2, 0x0001},
 {0x000AF9, 0x0004},
 {0x000AFA, 0x0010},
-{0x000B00, 0x0080},
+{0x000B00, 0x0001},
 {0x000B01, 0x0010},
-{0x000B04, 0x0080},
+{0x000B04, 0x0001},
 {0x000B05, 0x0004},
-{0x000B0D, 0x0080},
+{0x000B0D, 0x0001},
 {0x000B0F, 0x0004},
-{0x000B11, 0x0080},
+{0x000B11, 0x0001},
 {0x000B13, 0x0004},
-{0x000B29, 0x0080},
+{0x000B29, 0x0001},
 {0x000B2A, 0x0004},
-{0x000B31, 0x0080},
+{0x000B31, 0x0001},
 {0x000B32, 0x0004},
-{0x000B34, 0x0080},
+{0x000B34, 0x0001},
 {0x000B35, 0x0004},
-{0x000B3A, 0x0080},
+{0x000B3A, 0x0001},
 {0x000B3C, 0x0010},
 {0x000B3D, 0x0004},
 {0x000B3E, 0x0010},
-{0x000B45, 0x0080},
+{0x000B45, 0x0001},
 {0x000B47, 0x0010},
-{0x000B49, 0x0080},
+{0x000B49, 0x0001},
 {0x000B4B, 0x0010},
-{0x000B4E, 0x0080},
+{0x000B4E, 0x0001},
 {0x000B55, 0x0010},
-{0x000B58, 0x0080},
+{0x000B58, 0x0001},
 {0x000B5C, 0x0004},
-{0x000B5E, 0x0080},
+{0x000B5E, 0x0001},
 {0x000B5F, 0x0004},
 {0x000B62, 0x0010},
-{0x000B64, 0x0080},
+{0x000B64, 0x0001},
 {0x000B66, 0x0002},
 {0x000B70, 0x0040},
 {0x000B71, 0x0004},
 {0x000B72, 0x0002},
-{0x000B78, 0x0080},
+{0x000B78, 0x0001},
 {0x000B82, 0x0010},
 {0x000B83, 0x0004},
-{0x000B84, 0x0080},
+{0x000B84, 0x0001},
 {0x000B85, 0x0004},
-{0x000B8B, 0x0080},
+{0x000B8B, 0x0001},
 {0x000B8E, 0x0004},
-{0x000B91, 0x0080},
+{0x000B91, 0x0001},
 {0x000B92, 0x0004},
-{0x000B96, 0x0080},
+{0x000B96, 0x0001},
 {0x000B99, 0x0004},
-{0x000B9B, 0x0080},
+{0x000B9B, 0x0001},
 {0x000B9C, 0x0004},
-{0x000B9D, 0x0080},
+{0x000B9D, 0x0001},
 {0x000B9E, 0x0004},
-{0x000BA0, 0x0080},
+{0x000BA0, 0x0001},
 {0x000BA3, 0x0004},
-{0x000BA5, 0x0080},
+{0x000BA5, 0x0001},
 {0x000BA8, 0x0004},
-{0x000BAB, 0x0080},
+{0x000BAB, 0x0001},
 {0x000BAE, 0x0004},
-{0x000BBA, 0x0080},
+{0x000BBA, 0x0001},
 {0x000BBE, 0x0010},
-{0x000BC3, 0x0080},
+{0x000BC3, 0x0001},
 {0x000BC6, 0x0010},
-{0x000BC9, 0x0080},
+{0x000BC9, 0x0001},
 {0x000BCA, 0x0010},
-{0x000BCE, 0x0080},
+{0x000BCE, 0x0001},
 {0x000BD0, 0x0004},
-{0x000BD1, 0x0080},
+{0x000BD1, 0x0001},
 {0x000BD7, 0x0010},
-{0x000BD8, 0x0080},
+{0x000BD8, 0x0001},
 {0x000BE6, 0x0002},
 {0x000BF3, 0x0040},
-{0x000BFB, 0x0080},
+{0x000BFB, 0x0001},
 {0x000C00, 0x0010},
 {0x000C05, 0x0004},
-{0x000C0D, 0x0080},
+{0x000C0D, 0x0001},
 {0x000C0E, 0x0004},
-{0x000C11, 0x0080},
+{0x000C11, 0x0001},
 {0x000C12, 0x0004},
-{0x000C29, 0x0080},
+{0x000C29, 0x0001},
 {0x000C2A, 0x0004},
-{0x000C3A, 0x0080},
+{0x000C3A, 0x0001},
 {0x000C3C, 0x0010},
 {0x000C3D, 0x0004},
 {0x000C3E, 0x0010},
-{0x000C45, 0x0080},
+{0x000C45, 0x0001},
 {0x000C46, 0x0010},
-{0x000C49, 0x0080},
+{0x000C49, 0x0001},
 {0x000C4A, 0x0010},
-{0x000C4E, 0x0080},
+{0x000C4E, 0x0001},
 {0x000C55, 0x0010},
-{0x000C57, 0x0080},
+{0x000C57, 0x0001},
 {0x000C58, 0x0004},
-{0x000C5B, 0x0080},
+{0x000C5B, 0x0001},
 {0x000C5D, 0x0004},
-{0x000C5E, 0x0080},
+{0x000C5E, 0x0001},
 {0x000C60, 0x0004},
 {0x000C62, 0x0010},
-{0x000C64, 0x0080},
+{0x000C64, 0x0001},
 {0x000C66, 0x0002},
-{0x000C70, 0x0080},
+{0x000C70, 0x0001},
 {0x000C77, 0x0020},
 {0x000C78, 0x0002},
 {0x000C7F, 0x0040},
@@ -442,124 +446,124 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000C81, 0x0010},
 {0x000C84, 0x0020},
 {0x000C85, 0x0004},
-{0x000C8D, 0x0080},
+{0x000C8D, 0x0001},
 {0x000C8E, 0x0004},
-{0x000C91, 0x0080},
+{0x000C91, 0x0001},
 {0x000C92, 0x0004},
-{0x000CA9, 0x0080},
+{0x000CA9, 0x0001},
 {0x000CAA, 0x0004},
-{0x000CB4, 0x0080},
+{0x000CB4, 0x0001},
 {0x000CB5, 0x0004},
-{0x000CBA, 0x0080},
+{0x000CBA, 0x0001},
 {0x000CBC, 0x0010},
 {0x000CBD, 0x0004},
 {0x000CBE, 0x0010},
-{0x000CC5, 0x0080},
+{0x000CC5, 0x0001},
 {0x000CC6, 0x0010},
-{0x000CC9, 0x0080},
+{0x000CC9, 0x0001},
 {0x000CCA, 0x0010},
-{0x000CCE, 0x0080},
+{0x000CCE, 0x0001},
 {0x000CD5, 0x0010},
-{0x000CD7, 0x0080},
+{0x000CD7, 0x0001},
 {0x000CDD, 0x0004},
-{0x000CDF, 0x0080},
+{0x000CDF, 0x0001},
 {0x000CE0, 0x0004},
 {0x000CE2, 0x0010},
-{0x000CE4, 0x0080},
+{0x000CE4, 0x0001},
 {0x000CE6, 0x0002},
-{0x000CF0, 0x0080},
+{0x000CF0, 0x0001},
 {0x000CF1, 0x0004},
 {0x000CF3, 0x0010},
-{0x000CF4, 0x0080},
+{0x000CF4, 0x0001},
 {0x000D00, 0x0010},
 {0x000D04, 0x0004},
-{0x000D0D, 0x0080},
+{0x000D0D, 0x0001},
 {0x000D0E, 0x0004},
-{0x000D11, 0x0080},
+{0x000D11, 0x0001},
 {0x000D12, 0x0004},
 {0x000D3B, 0x0010},
 {0x000D3D, 0x0004},
 {0x000D3E, 0x0010},
-{0x000D45, 0x0080},
+{0x000D45, 0x0001},
 {0x000D46, 0x0010},
-{0x000D49, 0x0080},
+{0x000D49, 0x0001},
 {0x000D4A, 0x0010},
 {0x000D4E, 0x0004},
 {0x000D4F, 0x0040},
-{0x000D50, 0x0080},
+{0x000D50, 0x0001},
 {0x000D54, 0x0004},
 {0x000D57, 0x0010},
 {0x000D58, 0x0002},
 {0x000D5F, 0x0004},
 {0x000D62, 0x0010},
-{0x000D64, 0x0080},
+{0x000D64, 0x0001},
 {0x000D66, 0x0002},
 {0x000D79, 0x0040},
 {0x000D7A, 0x0004},
-{0x000D80, 0x0080},
+{0x000D80, 0x0001},
 {0x000D81, 0x0010},
-{0x000D84, 0x0080},
+{0x000D84, 0x0001},
 {0x000D85, 0x0004},
-{0x000D97, 0x0080},
+{0x000D97, 0x0001},
 {0x000D9A, 0x0004},
-{0x000DB2, 0x0080},
+{0x000DB2, 0x0001},
 {0x000DB3, 0x0004},
-{0x000DBC, 0x0080},
+{0x000DBC, 0x0001},
 {0x000DBD, 0x0004},
-{0x000DBE, 0x0080},
+{0x000DBE, 0x0001},
 {0x000DC0, 0x0004},
-{0x000DC7, 0x0080},
+{0x000DC7, 0x0001},
 {0x000DCA, 0x0010},
-{0x000DCB, 0x0080},
+{0x000DCB, 0x0001},
 {0x000DCF, 0x0010},
-{0x000DD5, 0x0080},
+{0x000DD5, 0x0001},
 {0x000DD6, 0x0010},
-{0x000DD7, 0x0080},
+{0x000DD7, 0x0001},
 {0x000DD8, 0x0010},
-{0x000DE0, 0x0080},
+{0x000DE0, 0x0001},
 {0x000DE6, 0x0002},
-{0x000DF0, 0x0080},
+{0x000DF0, 0x0001},
 {0x000DF2, 0x0010},
 {0x000DF4, 0x0020},
-{0x000DF5, 0x0080},
+{0x000DF5, 0x0001},
 {0x000E01, 0x0004},
 {0x000E31, 0x0010},
 {0x000E32, 0x0004},
 {0x000E34, 0x0010},
-{0x000E3B, 0x0080},
+{0x000E3B, 0x0001},
 {0x000E3F, 0x0040},
 {0x000E40, 0x0004},
 {0x000E47, 0x0010},
 {0x000E4F, 0x0020},
 {0x000E50, 0x0002},
 {0x000E5A, 0x0020},
-{0x000E5C, 0x0080},
+{0x000E5C, 0x0001},
 {0x000E81, 0x0004},
-{0x000E83, 0x0080},
+{0x000E83, 0x0001},
 {0x000E84, 0x0004},
-{0x000E85, 0x0080},
+{0x000E85, 0x0001},
 {0x000E86, 0x0004},
-{0x000E8B, 0x0080},
+{0x000E8B, 0x0001},
 {0x000E8C, 0x0004},
-{0x000EA4, 0x0080},
+{0x000EA4, 0x0001},
 {0x000EA5, 0x0004},
-{0x000EA6, 0x0080},
+{0x000EA6, 0x0001},
 {0x000EA7, 0x0004},
 {0x000EB1, 0x0010},
 {0x000EB2, 0x0004},
 {0x000EB4, 0x0010},
 {0x000EBD, 0x0004},
-{0x000EBE, 0x0080},
+{0x000EBE, 0x0001},
 {0x000EC0, 0x0004},
-{0x000EC5, 0x0080},
+{0x000EC5, 0x0001},
 {0x000EC6, 0x0004},
-{0x000EC7, 0x0080},
+{0x000EC7, 0x0001},
 {0x000EC8, 0x0010},
-{0x000ECF, 0x0080},
+{0x000ECF, 0x0001},
 {0x000ED0, 0x0002},
-{0x000EDA, 0x0080},
+{0x000EDA, 0x0001},
 {0x000EDC, 0x0004},
-{0x000EE0, 0x0080},
+{0x000EE0, 0x0001},
 {0x000F00, 0x0004},
 {0x000F01, 0x0040},
 {0x000F04, 0x0020},
@@ -578,26 +582,26 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x000F3A, 0x0020},
 {0x000F3E, 0x0010},
 {0x000F40, 0x0004},
-{0x000F48, 0x0080},
+{0x000F48, 0x0001},
 {0x000F49, 0x0004},
-{0x000F6D, 0x0080},
+{0x000F6D, 0x0001},
 {0x000F71, 0x0010},
 {0x000F85, 0x0020},
 {0x000F86, 0x0010},
 {0x000F88, 0x0004},
 {0x000F8D, 0x0010},
-{0x000F98, 0x0080},
+{0x000F98, 0x0001},
 {0x000F99, 0x0010},
-{0x000FBD, 0x0080},
+{0x000FBD, 0x0001},
 {0x000FBE, 0x0040},
 {0x000FC6, 0x0010},
 {0x000FC7, 0x0040},
-{0x000FCD, 0x0080},
+{0x000FCD, 0x0001},
 {0x000FCE, 0x0040},
 {0x000FD0, 0x0020},
 {0x000FD5, 0x0040},
 {0x000FD9, 0x0020},
-{0x000FDB, 0x0080},
+{0x000FDB, 0x0001},
 {0x001000, 0x0004},
 {0x00102B, 0x0010},
 {0x00103F, 0x0004},
@@ -621,56 +625,56 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00109A, 0x0010},
 {0x00109E, 0x0040},
 {0x0010A0, 0x0004},
-{0x0010C6, 0x0080},
+{0x0010C6, 0x0001},
 {0x0010C7, 0x0004},
-{0x0010C8, 0x0080},
+{0x0010C8, 0x0001},
 {0x0010CD, 0x0004},
-{0x0010CE, 0x0080},
+{0x0010CE, 0x0001},
 {0x0010D0, 0x0004},
 {0x0010FB, 0x0020},
 {0x0010FC, 0x0004},
-{0x001249, 0x0080},
+{0x001249, 0x0001},
 {0x00124A, 0x0004},
-{0x00124E, 0x0080},
+{0x00124E, 0x0001},
 {0x001250, 0x0004},
-{0x001257, 0x0080},
+{0x001257, 0x0001},
 {0x001258, 0x0004},
-{0x001259, 0x0080},
+{0x001259, 0x0001},
 {0x00125A, 0x0004},
-{0x00125E, 0x0080},
+{0x00125E, 0x0001},
 {0x001260, 0x0004},
-{0x001289, 0x0080},
+{0x001289, 0x0001},
 {0x00128A, 0x0004},
-{0x00128E, 0x0080},
+{0x00128E, 0x0001},
 {0x001290, 0x0004},
-{0x0012B1, 0x0080},
+{0x0012B1, 0x0001},
 {0x0012B2, 0x0004},
-{0x0012B6, 0x0080},
+{0x0012B6, 0x0001},
 {0x0012B8, 0x0004},
-{0x0012BF, 0x0080},
+{0x0012BF, 0x0001},
 {0x0012C0, 0x0004},
-{0x0012C1, 0x0080},
+{0x0012C1, 0x0001},
 {0x0012C2, 0x0004},
-{0x0012C6, 0x0080},
+{0x0012C6, 0x0001},
 {0x0012C8, 0x0004},
-{0x0012D7, 0x0080},
+{0x0012D7, 0x0001},
 {0x0012D8, 0x0004},
-{0x001311, 0x0080},
+{0x001311, 0x0001},
 {0x001312, 0x0004},
-{0x001316, 0x0080},
+{0x001316, 0x0001},
 {0x001318, 0x0004},
-{0x00135B, 0x0080},
+{0x00135B, 0x0001},
 {0x00135D, 0x0010},
 {0x001360, 0x0020},
 {0x001369, 0x0002},
-{0x00137D, 0x0080},
+{0x00137D, 0x0001},
 {0x001380, 0x0004},
 {0x001390, 0x0040},
-{0x00139A, 0x0080},
+{0x00139A, 0x0001},
 {0x0013A0, 0x0004},
-{0x0013F6, 0x0080},
+{0x0013F6, 0x0001},
 {0x0013F8, 0x0004},
-{0x0013FE, 0x0080},
+{0x0013FE, 0x0001},
 {0x001400, 0x0020},
 {0x001401, 0x0004},
 {0x00166D, 0x0040},
@@ -679,28 +683,28 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x001680, 0x0008},
 {0x001681, 0x0004},
 {0x00169B, 0x0020},
-{0x00169D, 0x0080},
+{0x00169D, 0x0001},
 {0x0016A0, 0x0004},
 {0x0016EB, 0x0020},
 {0x0016EE, 0x0002},
 {0x0016F1, 0x0004},
-{0x0016F9, 0x0080},
+{0x0016F9, 0x0001},
 {0x001700, 0x0004},
 {0x001712, 0x0010},
-{0x001716, 0x0080},
+{0x001716, 0x0001},
 {0x00171F, 0x0004},
 {0x001732, 0x0010},
 {0x001735, 0x0020},
-{0x001737, 0x0080},
+{0x001737, 0x0001},
 {0x001740, 0x0004},
 {0x001752, 0x0010},
-{0x001754, 0x0080},
+{0x001754, 0x0001},
 {0x001760, 0x0004},
-{0x00176D, 0x0080},
+{0x00176D, 0x0001},
 {0x00176E, 0x0004},
-{0x001771, 0x0080},
+{0x001771, 0x0001},
 {0x001772, 0x0010},
-{0x001774, 0x0080},
+{0x001774, 0x0001},
 {0x001780, 0x0004},
 {0x0017B4, 0x0010},
 {0x0017D4, 0x0020},
@@ -709,80 +713,80 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0017DB, 0x0040},
 {0x0017DC, 0x0004},
 {0x0017DD, 0x0010},
-{0x0017DE, 0x0080},
+{0x0017DE, 0x0001},
 {0x0017E0, 0x0002},
-{0x0017EA, 0x0080},
+{0x0017EA, 0x0001},
 {0x0017F0, 0x0002},
-{0x0017FA, 0x0080},
+{0x0017FA, 0x0001},
 {0x001800, 0x0020},
 {0x00180B, 0x0010},
 {0x00180E, 0x0080},
 {0x00180F, 0x0010},
 {0x001810, 0x0002},
-{0x00181A, 0x0080},
+{0x00181A, 0x0001},
 {0x001820, 0x0004},
-{0x001879, 0x0080},
+{0x001879, 0x0001},
 {0x001880, 0x0004},
 {0x001885, 0x0010},
 {0x001887, 0x0004},
 {0x0018A9, 0x0010},
 {0x0018AA, 0x0004},
-{0x0018AB, 0x0080},
+{0x0018AB, 0x0001},
 {0x0018B0, 0x0004},
-{0x0018F6, 0x0080},
+{0x0018F6, 0x0001},
 {0x001900, 0x0004},
-{0x00191F, 0x0080},
+{0x00191F, 0x0001},
 {0x001920, 0x0010},
-{0x00192C, 0x0080},
+{0x00192C, 0x0001},
 {0x001930, 0x0010},
-{0x00193C, 0x0080},
+{0x00193C, 0x0001},
 {0x001940, 0x0040},
-{0x001941, 0x0080},
+{0x001941, 0x0001},
 {0x001944, 0x0020},
 {0x001946, 0x0002},
 {0x001950, 0x0004},
-{0x00196E, 0x0080},
+{0x00196E, 0x0001},
 {0x001970, 0x0004},
-{0x001975, 0x0080},
+{0x001975, 0x0001},
 {0x001980, 0x0004},
-{0x0019AC, 0x0080},
+{0x0019AC, 0x0001},
 {0x0019B0, 0x0004},
-{0x0019CA, 0x0080},
+{0x0019CA, 0x0001},
 {0x0019D0, 0x0002},
-{0x0019DB, 0x0080},
+{0x0019DB, 0x0001},
 {0x0019DE, 0x0040},
 {0x001A00, 0x0004},
 {0x001A17, 0x0010},
-{0x001A1C, 0x0080},
+{0x001A1C, 0x0001},
 {0x001A1E, 0x0020},
 {0x001A20, 0x0004},
 {0x001A55, 0x0010},
-{0x001A5F, 0x0080},
+{0x001A5F, 0x0001},
 {0x001A60, 0x0010},
-{0x001A7D, 0x0080},
+{0x001A7D, 0x0001},
 {0x001A7F, 0x0010},
 {0x001A80, 0x0002},
-{0x001A8A, 0x0080},
+{0x001A8A, 0x0001},
 {0x001A90, 0x0002},
-{0x001A9A, 0x0080},
+{0x001A9A, 0x0001},
 {0x001AA0, 0x0020},
 {0x001AA7, 0x0004},
 {0x001AA8, 0x0020},
-{0x001AAE, 0x0080},
+{0x001AAE, 0x0001},
 {0x001AB0, 0x0010},
-{0x001ACF, 0x0080},
+{0x001ACF, 0x0001},
 {0x001B00, 0x0010},
 {0x001B05, 0x0004},
 {0x001B34, 0x0010},
 {0x001B45, 0x0004},
-{0x001B4D, 0x0080},
+{0x001B4D, 0x0001},
 {0x001B50, 0x0002},
 {0x001B5A, 0x0020},
 {0x001B61, 0x0040},
 {0x001B6B, 0x0010},
 {0x001B74, 0x0040},
 {0x001B7D, 0x0020},
-{0x001B7F, 0x0080},
+{0x001B7F, 0x0001},
 {0x001B80, 0x0010},
 {0x001B83, 0x0004},
 {0x001BA1, 0x0010},
@@ -790,25 +794,25 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x001BB0, 0x0002},
 {0x001BBA, 0x0004},
 {0x001BE6, 0x0010},
-{0x001BF4, 0x0080},
+{0x001BF4, 0x0001},
 {0x001BFC, 0x0020},
 {0x001C00, 0x0004},
 {0x001C24, 0x0010},
-{0x001C38, 0x0080},
+{0x001C38, 0x0001},
 {0x001C3B, 0x0020},
 {0x001C40, 0x0002},
-{0x001C4A, 0x0080},
+{0x001C4A, 0x0001},
 {0x001C4D, 0x0004},
 {0x001C50, 0x0002},
 {0x001C5A, 0x0004},
 {0x001C7E, 0x0020},
 {0x001C80, 0x0004},
-{0x001C89, 0x0080},
+{0x001C89, 0x0001},
 {0x001C90, 0x0004},
-{0x001CBB, 0x0080},
+{0x001CBB, 0x0001},
 {0x001CBD, 0x0004},
 {0x001CC0, 0x0020},
-{0x001CC8, 0x0080},
+{0x001CC8, 0x0001},
 {0x001CD0, 0x0010},
 {0x001CD3, 0x0020},
 {0x001CD4, 0x0010},
@@ -819,50 +823,50 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x001CF5, 0x0004},
 {0x001CF7, 0x0010},
 {0x001CFA, 0x0004},
-{0x001CFB, 0x0080},
+{0x001CFB, 0x0001},
 {0x001D00, 0x0004},
 {0x001DC0, 0x0010},
 {0x001E00, 0x0004},
-{0x001F16, 0x0080},
+{0x001F16, 0x0001},
 {0x001F18, 0x0004},
-{0x001F1E, 0x0080},
+{0x001F1E, 0x0001},
 {0x001F20, 0x0004},
-{0x001F46, 0x0080},
+{0x001F46, 0x0001},
 {0x001F48, 0x0004},
-{0x001F4E, 0x0080},
+{0x001F4E, 0x0001},
 {0x001F50, 0x0004},
-{0x001F58, 0x0080},
+{0x001F58, 0x0001},
 {0x001F59, 0x0004},
-{0x001F5A, 0x0080},
+{0x001F5A, 0x0001},
 {0x001F5B, 0x0004},
-{0x001F5C, 0x0080},
+{0x001F5C, 0x0001},
 {0x001F5D, 0x0004},
-{0x001F5E, 0x0080},
+{0x001F5E, 0x0001},
 {0x001F5F, 0x0004},
-{0x001F7E, 0x0080},
+{0x001F7E, 0x0001},
 {0x001F80, 0x0004},
-{0x001FB5, 0x0080},
+{0x001FB5, 0x0001},
 {0x001FB6, 0x0004},
 {0x001FBD, 0x0040},
 {0x001FBE, 0x0004},
 {0x001FBF, 0x0040},
 {0x001FC2, 0x0004},
-{0x001FC5, 0x0080},
+{0x001FC5, 0x0001},
 {0x001FC6, 0x0004},
 {0x001FCD, 0x0040},
 {0x001FD0, 0x0004},
-{0x001FD4, 0x0080},
+{0x001FD4, 0x0001},
 {0x001FD6, 0x0004},
-{0x001FDC, 0x0080},
+{0x001FDC, 0x0001},
 {0x001FDD, 0x0040},
 {0x001FE0, 0x0004},
 {0x001FED, 0x0040},
-{0x001FF0, 0x0080},
+{0x001FF0, 0x0001},
 {0x001FF2, 0x0004},
-{0x001FF5, 0x0080},
+{0x001FF5, 0x0001},
 {0x001FF6, 0x0004},
 {0x001FFD, 0x0040},
-{0x001FFF, 0x0080},
+{0x001FFF, 0x0001},
 {0x002000, 0x0008},
 {0x00200B, 0x0080},
 {0x002010, 0x0020},
@@ -876,9 +880,11 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x002053, 0x0020},
 {0x00205F, 0x0008},
 {0x002060, 0x0080},
+{0x002065, 0x0001},
+{0x002066, 0x0080},
 {0x002070, 0x0002},
 {0x002071, 0x0004},
-{0x002072, 0x0080},
+{0x002072, 0x0001},
 {0x002074, 0x0002},
 {0x00207A, 0x0040},
 {0x00207D, 0x0020},
@@ -886,13 +892,13 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x002080, 0x0002},
 {0x00208A, 0x0040},
 {0x00208D, 0x0020},
-{0x00208F, 0x0080},
+{0x00208F, 0x0001},
 {0x002090, 0x0004},
-{0x00209D, 0x0080},
+{0x00209D, 0x0001},
 {0x0020A0, 0x0040},
-{0x0020C1, 0x0080},
+{0x0020C1, 0x0001},
 {0x0020D0, 0x0010},
-{0x0020F1, 0x0080},
+{0x0020F1, 0x0001},
 {0x002100, 0x0040},
 {0x002102, 0x0004},
 {0x002103, 0x0040},
@@ -924,15 +930,15 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x002183, 0x0004},
 {0x002185, 0x0002},
 {0x00218A, 0x0040},
-{0x00218C, 0x0080},
+{0x00218C, 0x0001},
 {0x002190, 0x0040},
 {0x002308, 0x0020},
 {0x00230C, 0x0040},
 {0x002329, 0x0020},
 {0x00232B, 0x0040},
-{0x002427, 0x0080},
+{0x002427, 0x0001},
 {0x002440, 0x0040},
-{0x00244B, 0x0080},
+{0x00244B, 0x0001},
 {0x002460, 0x0002},
 {0x00249C, 0x0040},
 {0x0024EA, 0x0002},
@@ -950,62 +956,62 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0029DC, 0x0040},
 {0x0029FC, 0x0020},
 {0x0029FE, 0x0040},
-{0x002B74, 0x0080},
+{0x002B74, 0x0001},
 {0x002B76, 0x0040},
-{0x002B96, 0x0080},
+{0x002B96, 0x0001},
 {0x002B97, 0x0040},
 {0x002C00, 0x0004},
 {0x002CE5, 0x0040},
 {0x002CEB, 0x0004},
 {0x002CEF, 0x0010},
 {0x002CF2, 0x0004},
-{0x002CF4, 0x0080},
+{0x002CF4, 0x0001},
 {0x002CF9, 0x0020},
 {0x002CFD, 0x0002},
 {0x002CFE, 0x0020},
 {0x002D00, 0x0004},
-{0x002D26, 0x0080},
+{0x002D26, 0x0001},
 {0x002D27, 0x0004},
-{0x002D28, 0x0080},
+{0x002D28, 0x0001},
 {0x002D2D, 0x0004},
-{0x002D2E, 0x0080},
+{0x002D2E, 0x0001},
 {0x002D30, 0x0004},
-{0x002D68, 0x0080},
+{0x002D68, 0x0001},
 {0x002D6F, 0x0004},
 {0x002D70, 0x0020},
-{0x002D71, 0x0080},
+{0x002D71, 0x0001},
 {0x002D7F, 0x0010},
 {0x002D80, 0x0004},
-{0x002D97, 0x0080},
+{0x002D97, 0x0001},
 {0x002DA0, 0x0004},
-{0x002DA7, 0x0080},
+{0x002DA7, 0x0001},
 {0x002DA8, 0x0004},
-{0x002DAF, 0x0080},
+{0x002DAF, 0x0001},
 {0x002DB0, 0x0004},
-{0x002DB7, 0x0080},
+{0x002DB7, 0x0001},
 {0x002DB8, 0x0004},
-{0x002DBF, 0x0080},
+{0x002DBF, 0x0001},
 {0x002DC0, 0x0004},
-{0x002DC7, 0x0080},
+{0x002DC7, 0x0001},
 {0x002DC8, 0x0004},
-{0x002DCF, 0x0080},
+{0x002DCF, 0x0001},
 {0x002DD0, 0x0004},
-{0x002DD7, 0x0080},
+{0x002DD7, 0x0001},
 {0x002DD8, 0x0004},
-{0x002DDF, 0x0080},
+{0x002DDF, 0x0001},
 {0x002DE0, 0x0010},
 {0x002E00, 0x0020},
 {0x002E2F, 0x0004},
 {0x002E30, 0x0020},
 {0x002E50, 0x0040},
 {0x002E52, 0x0020},
-{0x002E5E, 0x0080},
+{0x002E5E, 0x0001},
 {0x002E80, 0x0040},
-{0x002E9A, 0x0080},
+{0x002E9A, 0x0001},
 {0x002E9B, 0x0040},
-{0x002EF4, 0x0080},
+{0x002EF4, 0x0001},
 {0x002F00, 0x0040},
-{0x002FD6, 0x0080},
+{0x002FD6, 0x0001},
 {0x002FF0, 0x0040},
 {0x003000, 0x0008},
 {0x003001, 0x0020},
@@ -1025,9 +1031,9 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00303B, 0x0004},
 {0x00303D, 0x0020},
 {0x00303E, 0x0040},
-{0x003040, 0x0080},
+{0x003040, 0x0001},
 {0x003041, 0x0004},
-{0x003097, 0x0080},
+{0x003097, 0x0001},
 {0x003099, 0x0010},
 {0x00309B, 0x0040},
 {0x00309D, 0x0004},
@@ -1035,21 +1041,21 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0030A1, 0x0004},
 {0x0030FB, 0x0020},
 {0x0030FC, 0x0004},
-{0x003100, 0x0080},
+{0x003100, 0x0001},
 {0x003105, 0x0004},
-{0x003130, 0x0080},
+{0x003130, 0x0001},
 {0x003131, 0x0004},
-{0x00318F, 0x0080},
+{0x00318F, 0x0001},
 {0x003190, 0x0040},
 {0x003192, 0x0002},
 {0x003196, 0x0040},
 {0x0031A0, 0x0004},
 {0x0031C0, 0x0040},
-{0x0031E4, 0x0080},
+{0x0031E4, 0x0001},
 {0x0031EF, 0x0040},
 {0x0031F0, 0x0004},
 {0x003200, 0x0040},
-{0x00321F, 0x0080},
+{0x00321F, 0x0001},
 {0x003220, 0x0002},
 {0x00322A, 0x0040},
 {0x003248, 0x0002},
@@ -1063,9 +1069,9 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x003400, 0x0004},
 {0x004DC0, 0x0040},
 {0x004E00, 0x0004},
-{0x00A48D, 0x0080},
+{0x00A48D, 0x0001},
 {0x00A490, 0x0040},
-{0x00A4C7, 0x0080},
+{0x00A4C7, 0x0001},
 {0x00A4D0, 0x0004},
 {0x00A4FE, 0x0020},
 {0x00A500, 0x0004},
@@ -1073,7 +1079,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A610, 0x0004},
 {0x00A620, 0x0002},
 {0x00A62A, 0x0004},
-{0x00A62C, 0x0080},
+{0x00A62C, 0x0001},
 {0x00A640, 0x0004},
 {0x00A66F, 0x0010},
 {0x00A673, 0x0020},
@@ -1085,20 +1091,20 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A6E6, 0x0002},
 {0x00A6F0, 0x0010},
 {0x00A6F2, 0x0020},
-{0x00A6F8, 0x0080},
+{0x00A6F8, 0x0001},
 {0x00A700, 0x0040},
 {0x00A717, 0x0004},
 {0x00A720, 0x0040},
 {0x00A722, 0x0004},
 {0x00A789, 0x0040},
 {0x00A78B, 0x0004},
-{0x00A7CB, 0x0080},
+{0x00A7CB, 0x0001},
 {0x00A7D0, 0x0004},
-{0x00A7D2, 0x0080},
+{0x00A7D2, 0x0001},
 {0x00A7D3, 0x0004},
-{0x00A7D4, 0x0080},
+{0x00A7D4, 0x0001},
 {0x00A7D5, 0x0004},
-{0x00A7DA, 0x0080},
+{0x00A7DA, 0x0001},
 {0x00A7F2, 0x0004},
 {0x00A802, 0x0010},
 {0x00A803, 0x0004},
@@ -1109,20 +1115,20 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A823, 0x0010},
 {0x00A828, 0x0040},
 {0x00A82C, 0x0010},
-{0x00A82D, 0x0080},
+{0x00A82D, 0x0001},
 {0x00A830, 0x0002},
 {0x00A836, 0x0040},
-{0x00A83A, 0x0080},
+{0x00A83A, 0x0001},
 {0x00A840, 0x0004},
 {0x00A874, 0x0020},
-{0x00A878, 0x0080},
+{0x00A878, 0x0001},
 {0x00A880, 0x0010},
 {0x00A882, 0x0004},
 {0x00A8B4, 0x0010},
-{0x00A8C6, 0x0080},
+{0x00A8C6, 0x0001},
 {0x00A8CE, 0x0020},
 {0x00A8D0, 0x0002},
-{0x00A8DA, 0x0080},
+{0x00A8DA, 0x0001},
 {0x00A8E0, 0x0010},
 {0x00A8F2, 0x0004},
 {0x00A8F8, 0x0020},
@@ -1136,35 +1142,35 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00A92E, 0x0020},
 {0x00A930, 0x0004},
 {0x00A947, 0x0010},
-{0x00A954, 0x0080},
+{0x00A954, 0x0001},
 {0x00A95F, 0x0020},
 {0x00A960, 0x0004},
-{0x00A97D, 0x0080},
+{0x00A97D, 0x0001},
 {0x00A980, 0x0010},
 {0x00A984, 0x0004},
 {0x00A9B3, 0x0010},
 {0x00A9C1, 0x0020},
-{0x00A9CE, 0x0080},
+{0x00A9CE, 0x0001},
 {0x00A9CF, 0x0004},
 {0x00A9D0, 0x0002},
-{0x00A9DA, 0x0080},
+{0x00A9DA, 0x0001},
 {0x00A9DE, 0x0020},
 {0x00A9E0, 0x0004},
 {0x00A9E5, 0x0010},
 {0x00A9E6, 0x0004},
 {0x00A9F0, 0x0002},
 {0x00A9FA, 0x0004},
-{0x00A9FF, 0x0080},
+{0x00A9FF, 0x0001},
 {0x00AA00, 0x0004},
 {0x00AA29, 0x0010},
-{0x00AA37, 0x0080},
+{0x00AA37, 0x0001},
 {0x00AA40, 0x0004},
 {0x00AA43, 0x0010},
 {0x00AA44, 0x0004},
 {0x00AA4C, 0x0010},
-{0x00AA4E, 0x0080},
+{0x00AA4E, 0x0001},
 {0x00AA50, 0x0002},
-{0x00AA5A, 0x0080},
+{0x00AA5A, 0x0001},
 {0x00AA5C, 0x0020},
 {0x00AA60, 0x0004},
 {0x00AA77, 0x0040},
@@ -1181,7 +1187,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00AAC0, 0x0004},
 {0x00AAC1, 0x0010},
 {0x00AAC2, 0x0004},
-{0x00AAC3, 0x0080},
+{0x00AAC3, 0x0001},
 {0x00AADB, 0x0004},
 {0x00AADE, 0x0020},
 {0x00AAE0, 0x0004},
@@ -1189,90 +1195,93 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00AAF0, 0x0020},
 {0x00AAF2, 0x0004},
 {0x00AAF5, 0x0010},
-{0x00AAF7, 0x0080},
+{0x00AAF7, 0x0001},
 {0x00AB01, 0x0004},
-{0x00AB07, 0x0080},
+{0x00AB07, 0x0001},
 {0x00AB09, 0x0004},
-{0x00AB0F, 0x0080},
+{0x00AB0F, 0x0001},
 {0x00AB11, 0x0004},
-{0x00AB17, 0x0080},
+{0x00AB17, 0x0001},
 {0x00AB20, 0x0004},
-{0x00AB27, 0x0080},
+{0x00AB27, 0x0001},
 {0x00AB28, 0x0004},
-{0x00AB2F, 0x0080},
+{0x00AB2F, 0x0001},
 {0x00AB30, 0x0004},
 {0x00AB5B, 0x0040},
 {0x00AB5C, 0x0004},
 {0x00AB6A, 0x0040},
-{0x00AB6C, 0x0080},
+{0x00AB6C, 0x0001},
 {0x00AB70, 0x0004},
 {0x00ABE3, 0x0010},
 {0x00ABEB, 0x0020},
 {0x00ABEC, 0x0010},
-{0x00ABEE, 0x0080},
+{0x00ABEE, 0x0001},
 {0x00ABF0, 0x0002},
-{0x00ABFA, 0x0080},
+{0x00ABFA, 0x0001},
 {0x00AC00, 0x0004},
-{0x00D7A4, 0x0080},
+{0x00D7A4, 0x0001},
 {0x00D7B0, 0x0004},
-{0x00D7C7, 0x0080},
+{0x00D7C7, 0x0001},
 {0x00D7CB, 0x0004},
-{0x00D7FC, 0x0080},
+{0x00D7FC, 0x0001},
+{0x00D800, 0x0080},
 {0x00F900, 0x0004},
-{0x00FA6E, 0x0080},
+{0x00FA6E, 0x0001},
 {0x00FA70, 0x0004},
-{0x00FADA, 0x0080},
+{0x00FADA, 0x0001},
 {0x00FB00, 0x0004},
-{0x00FB07, 0x0080},
+{0x00FB07, 0x0001},
 {0x00FB13, 0x0004},
-{0x00FB18, 0x0080},
+{0x00FB18, 0x0001},
 {0x00FB1D, 0x0004},
 {0x00FB1E, 0x0010},
 {0x00FB1F, 0x0004},
 {0x00FB29, 0x0040},
 {0x00FB2A, 0x0004},
-{0x00FB37, 0x0080},
+{0x00FB37, 0x0001},
 {0x00FB38, 0x0004},
-{0x00FB3D, 0x0080},
+{0x00FB3D, 0x0001},
 {0x00FB3E, 0x0004},
-{0x00FB3F, 0x0080},
+{0x00FB3F, 0x0001},
 {0x00FB40, 0x0004},
-{0x00FB42, 0x0080},
+{0x00FB42, 0x0001},
 {0x00FB43, 0x0004},
-{0x00FB45, 0x0080},
+{0x00FB45, 0x0001},
 {0x00FB46, 0x0004},
 {0x00FBB2, 0x0040},
-{0x00FBC3, 0x0080},
+{0x00FBC3, 0x0001},
 {0x00FBD3, 0x0004},
 {0x00FD3E, 0x0020},
 {0x00FD40, 0x0040},
 {0x00FD50, 0x0004},
-{0x00FD90, 0x0080},
+{0x00FD90, 0x0001},
 {0x00FD92, 0x0004},
-{0x00FDC8, 0x0080},
+{0x00FDC8, 0x0001},
 {0x00FDCF, 0x0040},
-{0x00FDD0, 0x0080},
+{0x00FDD0, 0x0001},
 {0x00FDF0, 0x0004},
 {0x00FDFC, 0x0040},
 {0x00FE00, 0x0010},
 {0x00FE10, 0x0020},
-{0x00FE1A, 0x0080},
+{0x00FE1A, 0x0001},
 {0x00FE20, 0x0010},
 {0x00FE30, 0x0020},
-{0x00FE53, 0x0080},
+{0x00FE53, 0x0001},
 {0x00FE54, 0x0020},
 {0x00FE62, 0x0040},
 {0x00FE63, 0x0020},
 {0x00FE64, 0x0040},
-{0x00FE67, 0x0080},
+{0x00FE67, 0x0001},
 {0x00FE68, 0x0020},
 {0x00FE69, 0x0040},
 {0x00FE6A, 0x0020},
-{0x00FE6C, 0x0080},
+{0x00FE6C, 0x0001},
 {0x00FE70, 0x0004},
-{0x00FE75, 0x0080},
+{0x00FE75, 0x0001},
 {0x00FE76, 0x0004},
-{0x00FEFD, 0x0080},
+{0x00FEFD, 0x0001},
+{0x00FEFF, 0x0080},
+{0x00FF00, 0x0001},
 {0x00FF01, 0x0020},
 {0x00FF04, 0x0040},
 {0x00FF05, 0x0020},
@@ -1294,260 +1303,261 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x00FF5E, 0x0040},
 {0x00FF5F, 0x0020},
 {0x00FF66, 0x0004},
-{0x00FFBF, 0x0080},
+{0x00FFBF, 0x0001},
 {0x00FFC2, 0x0004},
-{0x00FFC8, 0x0080},
+{0x00FFC8, 0x0001},
 {0x00FFCA, 0x0004},
-{0x00FFD0, 0x0080},
+{0x00FFD0, 0x0001},
 {0x00FFD2, 0x0004},
-{0x00FFD8, 0x0080},
+{0x00FFD8, 0x0001},
 {0x00FFDA, 0x0004},
-{0x00FFDD, 0x0080},
+{0x00FFDD, 0x0001},
 {0x00FFE0, 0x0040},
-{0x00FFE7, 0x0080},
+{0x00FFE7, 0x0001},
 {0x00FFE8, 0x0040},
-{0x00FFEF, 0x0080},
+{0x00FFEF, 0x0001},
+{0x00FFF9, 0x0080},
 {0x00FFFC, 0x0040},
-{0x00FFFE, 0x0080},
+{0x00FFFE, 0x0001},
 {0x010000, 0x0004},
-{0x01000C, 0x0080},
+{0x01000C, 0x0001},
 {0x01000D, 0x0004},
-{0x010027, 0x0080},
+{0x010027, 0x0001},
 {0x010028, 0x0004},
-{0x01003B, 0x0080},
+{0x01003B, 0x0001},
 {0x01003C, 0x0004},
-{0x01003E, 0x0080},
+{0x01003E, 0x0001},
 {0x01003F, 0x0004},
-{0x01004E, 0x0080},
+{0x01004E, 0x0001},
 {0x010050, 0x0004},
-{0x01005E, 0x0080},
+{0x01005E, 0x0001},
 {0x010080, 0x0004},
-{0x0100FB, 0x0080},
+{0x0100FB, 0x0001},
 {0x010100, 0x0020},
-{0x010103, 0x0080},
+{0x010103, 0x0001},
 {0x010107, 0x0002},
-{0x010134, 0x0080},
+{0x010134, 0x0001},
 {0x010137, 0x0040},
 {0x010140, 0x0002},
 {0x010179, 0x0040},
 {0x01018A, 0x0002},
 {0x01018C, 0x0040},
-{0x01018F, 0x0080},
+{0x01018F, 0x0001},
 {0x010190, 0x0040},
-{0x01019D, 0x0080},
+{0x01019D, 0x0001},
 {0x0101A0, 0x0040},
-{0x0101A1, 0x0080},
+{0x0101A1, 0x0001},
 {0x0101D0, 0x0040},
 {0x0101FD, 0x0010},
-{0x0101FE, 0x0080},
+{0x0101FE, 0x0001},
 {0x010280, 0x0004},
-{0x01029D, 0x0080},
+{0x01029D, 0x0001},
 {0x0102A0, 0x0004},
-{0x0102D1, 0x0080},
+{0x0102D1, 0x0001},
 {0x0102E0, 0x0010},
 {0x0102E1, 0x0002},
-{0x0102FC, 0x0080},
+{0x0102FC, 0x0001},
 {0x010300, 0x0004},
 {0x010320, 0x0002},
-{0x010324, 0x0080},
+{0x010324, 0x0001},
 {0x01032D, 0x0004},
 {0x010341, 0x0002},
 {0x010342, 0x0004},
 {0x01034A, 0x0002},
-{0x01034B, 0x0080},
+{0x01034B, 0x0001},
 {0x010350, 0x0004},
 {0x010376, 0x0010},
-{0x01037B, 0x0080},
+{0x01037B, 0x0001},
 {0x010380, 0x0004},
-{0x01039E, 0x0080},
+{0x01039E, 0x0001},
 {0x01039F, 0x0020},
 {0x0103A0, 0x0004},
-{0x0103C4, 0x0080},
+{0x0103C4, 0x0001},
 {0x0103C8, 0x0004},
 {0x0103D0, 0x0020},
 {0x0103D1, 0x0002},
-{0x0103D6, 0x0080},
+{0x0103D6, 0x0001},
 {0x010400, 0x0004},
-{0x01049E, 0x0080},
+{0x01049E, 0x0001},
 {0x0104A0, 0x0002},
-{0x0104AA, 0x0080},
+{0x0104AA, 0x0001},
 {0x0104B0, 0x0004},
-{0x0104D4, 0x0080},
+{0x0104D4, 0x0001},
 {0x0104D8, 0x0004},
-{0x0104FC, 0x0080},
+{0x0104FC, 0x0001},
 {0x010500, 0x0004},
-{0x010528, 0x0080},
+{0x010528, 0x0001},
 {0x010530, 0x0004},
-{0x010564, 0x0080},
+{0x010564, 0x0001},
 {0x01056F, 0x0020},
 {0x010570, 0x0004},
-{0x01057B, 0x0080},
+{0x01057B, 0x0001},
 {0x01057C, 0x0004},
-{0x01058B, 0x0080},
+{0x01058B, 0x0001},
 {0x01058C, 0x0004},
-{0x010593, 0x0080},
+{0x010593, 0x0001},
 {0x010594, 0x0004},
-{0x010596, 0x0080},
+{0x010596, 0x0001},
 {0x010597, 0x0004},
-{0x0105A2, 0x0080},
+{0x0105A2, 0x0001},
 {0x0105A3, 0x0004},
-{0x0105B2, 0x0080},
+{0x0105B2, 0x0001},
 {0x0105B3, 0x0004},
-{0x0105BA, 0x0080},
+{0x0105BA, 0x0001},
 {0x0105BB, 0x0004},
-{0x0105BD, 0x0080},
+{0x0105BD, 0x0001},
 {0x010600, 0x0004},
-{0x010737, 0x0080},
+{0x010737, 0x0001},
 {0x010740, 0x0004},
-{0x010756, 0x0080},
+{0x010756, 0x0001},
 {0x010760, 0x0004},
-{0x010768, 0x0080},
+{0x010768, 0x0001},
 {0x010780, 0x0004},
-{0x010786, 0x0080},
+{0x010786, 0x0001},
 {0x010787, 0x0004},
-{0x0107B1, 0x0080},
+{0x0107B1, 0x0001},
 {0x0107B2, 0x0004},
-{0x0107BB, 0x0080},
+{0x0107BB, 0x0001},
 {0x010800, 0x0004},
-{0x010806, 0x0080},
+{0x010806, 0x0001},
 {0x010808, 0x0004},
-{0x010809, 0x0080},
+{0x010809, 0x0001},
 {0x01080A, 0x0004},
-{0x010836, 0x0080},
+{0x010836, 0x0001},
 {0x010837, 0x0004},
-{0x010839, 0x0080},
+{0x010839, 0x0001},
 {0x01083C, 0x0004},
-{0x01083D, 0x0080},
+{0x01083D, 0x0001},
 {0x01083F, 0x0004},
-{0x010856, 0x0080},
+{0x010856, 0x0001},
 {0x010857, 0x0020},
 {0x010858, 0x0002},
 {0x010860, 0x0004},
 {0x010877, 0x0040},
 {0x010879, 0x0002},
 {0x010880, 0x0004},
-{0x01089F, 0x0080},
+{0x01089F, 0x0001},
 {0x0108A7, 0x0002},
-{0x0108B0, 0x0080},
+{0x0108B0, 0x0001},
 {0x0108E0, 0x0004},
-{0x0108F3, 0x0080},
+{0x0108F3, 0x0001},
 {0x0108F4, 0x0004},
-{0x0108F6, 0x0080},
+{0x0108F6, 0x0001},
 {0x0108FB, 0x0002},
 {0x010900, 0x0004},
 {0x010916, 0x0002},
-{0x01091C, 0x0080},
+{0x01091C, 0x0001},
 {0x01091F, 0x0020},
 {0x010920, 0x0004},
-{0x01093A, 0x0080},
+{0x01093A, 0x0001},
 {0x01093F, 0x0020},
-{0x010940, 0x0080},
+{0x010940, 0x0001},
 {0x010980, 0x0004},
-{0x0109B8, 0x0080},
+{0x0109B8, 0x0001},
 {0x0109BC, 0x0002},
 {0x0109BE, 0x0004},
 {0x0109C0, 0x0002},
-{0x0109D0, 0x0080},
+{0x0109D0, 0x0001},
 {0x0109D2, 0x0002},
 {0x010A00, 0x0004},
 {0x010A01, 0x0010},
-{0x010A04, 0x0080},
+{0x010A04, 0x0001},
 {0x010A05, 0x0010},
-{0x010A07, 0x0080},
+{0x010A07, 0x0001},
 {0x010A0C, 0x0010},
 {0x010A10, 0x0004},
-{0x010A14, 0x0080},
+{0x010A14, 0x0001},
 {0x010A15, 0x0004},
-{0x010A18, 0x0080},
+{0x010A18, 0x0001},
 {0x010A19, 0x0004},
-{0x010A36, 0x0080},
+{0x010A36, 0x0001},
 {0x010A38, 0x0010},
-{0x010A3B, 0x0080},
+{0x010A3B, 0x0001},
 {0x010A3F, 0x0010},
 {0x010A40, 0x0002},
-{0x010A49, 0x0080},
+{0x010A49, 0x0001},
 {0x010A50, 0x0020},
-{0x010A59, 0x0080},
+{0x010A59, 0x0001},
 {0x010A60, 0x0004},
 {0x010A7D, 0x0002},
 {0x010A7F, 0x0020},
 {0x010A80, 0x0004},
 {0x010A9D, 0x0002},
-{0x010AA0, 0x0080},
+{0x010AA0, 0x0001},
 {0x010AC0, 0x0004},
 {0x010AC8, 0x0040},
 {0x010AC9, 0x0004},
 {0x010AE5, 0x0010},
-{0x010AE7, 0x0080},
+{0x010AE7, 0x0001},
 {0x010AEB, 0x0002},
 {0x010AF0, 0x0020},
-{0x010AF7, 0x0080},
+{0x010AF7, 0x0001},
 {0x010B00, 0x0004},
-{0x010B36, 0x0080},
+{0x010B36, 0x0001},
 {0x010B39, 0x0020},
 {0x010B40, 0x0004},
-{0x010B56, 0x0080},
+{0x010B56, 0x0001},
 {0x010B58, 0x0002},
 {0x010B60, 0x0004},
-{0x010B73, 0x0080},
+{0x010B73, 0x0001},
 {0x010B78, 0x0002},
 {0x010B80, 0x0004},
-{0x010B92, 0x0080},
+{0x010B92, 0x0001},
 {0x010B99, 0x0020},
-{0x010B9D, 0x0080},
+{0x010B9D, 0x0001},
 {0x010BA9, 0x0002},
-{0x010BB0, 0x0080},
+{0x010BB0, 0x0001},
 {0x010C00, 0x0004},
-{0x010C49, 0x0080},
+{0x010C49, 0x0001},
 {0x010C80, 0x0004},
-{0x010CB3, 0x0080},
+{0x010CB3, 0x0001},
 {0x010CC0, 0x0004},
-{0x010CF3, 0x0080},
+{0x010CF3, 0x0001},
 {0x010CFA, 0x0002},
 {0x010D00, 0x0004},
 {0x010D24, 0x0010},
-{0x010D28, 0x0080},
+{0x010D28, 0x0001},
 {0x010D30, 0x0002},
-{0x010D3A, 0x0080},
+{0x010D3A, 0x0001},
 {0x010E60, 0x0002},
-{0x010E7F, 0x0080},
+{0x010E7F, 0x0001},
 {0x010E80, 0x0004},
-{0x010EAA, 0x0080},
+{0x010EAA, 0x0001},
 {0x010EAB, 0x0010},
 {0x010EAD, 0x0020},
-{0x010EAE, 0x0080},
+{0x010EAE, 0x0001},
 {0x010EB0, 0x0004},
-{0x010EB2, 0x0080},
+{0x010EB2, 0x0001},
 {0x010EFD, 0x0010},
 {0x010F00, 0x0004},
 {0x010F1D, 0x0002},
 {0x010F27, 0x0004},
-{0x010F28, 0x0080},
+{0x010F28, 0x0001},
 {0x010F30, 0x0004},
 {0x010F46, 0x0010},
 {0x010F51, 0x0002},
 {0x010F55, 0x0020},
-{0x010F5A, 0x0080},
+{0x010F5A, 0x0001},
 {0x010F70, 0x0004},
 {0x010F82, 0x0010},
 {0x010F86, 0x0020},
-{0x010F8A, 0x0080},
+{0x010F8A, 0x0001},
 {0x010FB0, 0x0004},
 {0x010FC5, 0x0002},
-{0x010FCC, 0x0080},
+{0x010FCC, 0x0001},
 {0x010FE0, 0x0004},
-{0x010FF7, 0x0080},
+{0x010FF7, 0x0001},
 {0x011000, 0x0010},
 {0x011003, 0x0004},
 {0x011038, 0x0010},
 {0x011047, 0x0020},
-{0x01104E, 0x0080},
+{0x01104E, 0x0001},
 {0x011052, 0x0002},
 {0x011070, 0x0010},
 {0x011071, 0x0004},
 {0x011073, 0x0010},
 {0x011075, 0x0004},
-{0x011076, 0x0080},
+{0x011076, 0x0001},
 {0x01107F, 0x0010},
 {0x011083, 0x0004},
 {0x0110B0, 0x0010},
@@ -1555,26 +1565,28 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0110BD, 0x0080},
 {0x0110BE, 0x0020},
 {0x0110C2, 0x0010},
-{0x0110C3, 0x0080},
+{0x0110C3, 0x0001},
+{0x0110CD, 0x0080},
+{0x0110CE, 0x0001},
 {0x0110D0, 0x0004},
-{0x0110E9, 0x0080},
+{0x0110E9, 0x0001},
 {0x0110F0, 0x0002},
-{0x0110FA, 0x0080},
+{0x0110FA, 0x0001},
 {0x011100, 0x0010},
 {0x011103, 0x0004},
 {0x011127, 0x0010},
-{0x011135, 0x0080},
+{0x011135, 0x0001},
 {0x011136, 0x0002},
 {0x011140, 0x0020},
 {0x011144, 0x0004},
 {0x011145, 0x0010},
 {0x011147, 0x0004},
-{0x011148, 0x0080},
+{0x011148, 0x0001},
 {0x011150, 0x0004},
 {0x011173, 0x0010},
 {0x011174, 0x0020},
 {0x011176, 0x0004},
-{0x011177, 0x0080},
+{0x011177, 0x0001},
 {0x011180, 0x0010},
 {0x011183, 0x0004},
 {0x0111B3, 0x0010},
@@ -1588,159 +1600,159 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x0111DB, 0x0020},
 {0x0111DC, 0x0004},
 {0x0111DD, 0x0020},
-{0x0111E0, 0x0080},
+{0x0111E0, 0x0001},
 {0x0111E1, 0x0002},
-{0x0111F5, 0x0080},
+{0x0111F5, 0x0001},
 {0x011200, 0x0004},
-{0x011212, 0x0080},
+{0x011212, 0x0001},
 {0x011213, 0x0004},
 {0x01122C, 0x0010},
 {0x011238, 0x0020},
 {0x01123E, 0x0010},
 {0x01123F, 0x0004},
 {0x011241, 0x0010},
-{0x011242, 0x0080},
+{0x011242, 0x0001},
 {0x011280, 0x0004},
-{0x011287, 0x0080},
+{0x011287, 0x0001},
 {0x011288, 0x0004},
-{0x011289, 0x0080},
+{0x011289, 0x0001},
 {0x01128A, 0x0004},
-{0x01128E, 0x0080},
+{0x01128E, 0x0001},
 {0x01128F, 0x0004},
-{0x01129E, 0x0080},
+{0x01129E, 0x0001},
 {0x01129F, 0x0004},
 {0x0112A9, 0x0020},
-{0x0112AA, 0x0080},
+{0x0112AA, 0x0001},
 {0x0112B0, 0x0004},
 {0x0112DF, 0x0010},
-{0x0112EB, 0x0080},
+{0x0112EB, 0x0001},
 {0x0112F0, 0x0002},
-{0x0112FA, 0x0080},
+{0x0112FA, 0x0001},
 {0x011300, 0x0010},
-{0x011304, 0x0080},
+{0x011304, 0x0001},
 {0x011305, 0x0004},
-{0x01130D, 0x0080},
+{0x01130D, 0x0001},
 {0x01130F, 0x0004},
-{0x011311, 0x0080},
+{0x011311, 0x0001},
 {0x011313, 0x0004},
-{0x011329, 0x0080},
+{0x011329, 0x0001},
 {0x01132A, 0x0004},
-{0x011331, 0x0080},
+{0x011331, 0x0001},
 {0x011332, 0x0004},
-{0x011334, 0x0080},
+{0x011334, 0x0001},
 {0x011335, 0x0004},
-{0x01133A, 0x0080},
+{0x01133A, 0x0001},
 {0x01133B, 0x0010},
 {0x01133D, 0x0004},
 {0x01133E, 0x0010},
-{0x011345, 0x0080},
+{0x011345, 0x0001},
 {0x011347, 0x0010},
-{0x011349, 0x0080},
+{0x011349, 0x0001},
 {0x01134B, 0x0010},
-{0x01134E, 0x0080},
+{0x01134E, 0x0001},
 {0x011350, 0x0004},
-{0x011351, 0x0080},
+{0x011351, 0x0001},
 {0x011357, 0x0010},
-{0x011358, 0x0080},
+{0x011358, 0x0001},
 {0x01135D, 0x0004},
 {0x011362, 0x0010},
-{0x011364, 0x0080},
+{0x011364, 0x0001},
 {0x011366, 0x0010},
-{0x01136D, 0x0080},
+{0x01136D, 0x0001},
 {0x011370, 0x0010},
-{0x011375, 0x0080},
+{0x011375, 0x0001},
 {0x011400, 0x0004},
 {0x011435, 0x0010},
 {0x011447, 0x0004},
 {0x01144B, 0x0020},
 {0x011450, 0x0002},
 {0x01145A, 0x0020},
-{0x01145C, 0x0080},
+{0x01145C, 0x0001},
 {0x01145D, 0x0020},
 {0x01145E, 0x0010},
 {0x01145F, 0x0004},
-{0x011462, 0x0080},
+{0x011462, 0x0001},
 {0x011480, 0x0004},
 {0x0114B0, 0x0010},
 {0x0114C4, 0x0004},
 {0x0114C6, 0x0020},
 {0x0114C7, 0x0004},
-{0x0114C8, 0x0080},
+{0x0114C8, 0x0001},
 {0x0114D0, 0x0002},
-{0x0114DA, 0x0080},
+{0x0114DA, 0x0001},
 {0x011580, 0x0004},
 {0x0115AF, 0x0010},
-{0x0115B6, 0x0080},
+{0x0115B6, 0x0001},
 {0x0115B8, 0x0010},
 {0x0115C1, 0x0020},
 {0x0115D8, 0x0004},
 {0x0115DC, 0x0010},
-{0x0115DE, 0x0080},
+{0x0115DE, 0x0001},
 {0x011600, 0x0004},
 {0x011630, 0x0010},
 {0x011641, 0x0020},
 {0x011644, 0x0004},
-{0x011645, 0x0080},
+{0x011645, 0x0001},
 {0x011650, 0x0002},
-{0x01165A, 0x0080},
+{0x01165A, 0x0001},
 {0x011660, 0x0020},
-{0x01166D, 0x0080},
+{0x01166D, 0x0001},
 {0x011680, 0x0004},
 {0x0116AB, 0x0010},
 {0x0116B8, 0x0004},
 {0x0116B9, 0x0020},
-{0x0116BA, 0x0080},
+{0x0116BA, 0x0001},
 {0x0116C0, 0x0002},
-{0x0116CA, 0x0080},
+{0x0116CA, 0x0001},
 {0x011700, 0x0004},
-{0x01171B, 0x0080},
+{0x01171B, 0x0001},
 {0x01171D, 0x0010},
-{0x01172C, 0x0080},
+{0x01172C, 0x0001},
 {0x011730, 0x0002},
 {0x01173C, 0x0020},
 {0x01173F, 0x0040},
 {0x011740, 0x0004},
-{0x011747, 0x0080},
+{0x011747, 0x0001},
 {0x011800, 0x0004},
 {0x01182C, 0x0010},
 {0x01183B, 0x0020},
-{0x01183C, 0x0080},
+{0x01183C, 0x0001},
 {0x0118A0, 0x0004},
 {0x0118E0, 0x0002},
-{0x0118F3, 0x0080},
+{0x0118F3, 0x0001},
 {0x0118FF, 0x0004},
-{0x011907, 0x0080},
+{0x011907, 0x0001},
 {0x011909, 0x0004},
-{0x01190A, 0x0080},
+{0x01190A, 0x0001},
 {0x01190C, 0x0004},
-{0x011914, 0x0080},
+{0x011914, 0x0001},
 {0x011915, 0x0004},
-{0x011917, 0x0080},
+{0x011917, 0x0001},
 {0x011918, 0x0004},
 {0x011930, 0x0010},
-{0x011936, 0x0080},
+{0x011936, 0x0001},
 {0x011937, 0x0010},
-{0x011939, 0x0080},
+{0x011939, 0x0001},
 {0x01193B, 0x0010},
 {0x01193F, 0x0004},
 {0x011940, 0x0010},
 {0x011941, 0x0004},
 {0x011942, 0x0010},
 {0x011944, 0x0020},
-{0x011947, 0x0080},
+{0x011947, 0x0001},
 {0x011950, 0x0002},
-{0x01195A, 0x0080},
+{0x01195A, 0x0001},
 {0x0119A0, 0x0004},
-{0x0119A8, 0x0080},
+{0x0119A8, 0x0001},
 {0x0119AA, 0x0004},
 {0x0119D1, 0x0010},
-{0x0119D8, 0x0080},
+{0x0119D8, 0x0001},
 {0x0119DA, 0x0010},
 {0x0119E1, 0x0004},
 {0x0119E2, 0x0020},
 {0x0119E3, 0x0004},
 {0x0119E4, 0x0010},
-{0x0119E5, 0x0080},
+{0x0119E5, 0x0001},
 {0x011A00, 0x0004},
 {0x011A01, 0x0010},
 {0x011A0B, 0x0004},
@@ -1749,7 +1761,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x011A3B, 0x0010},
 {0x011A3F, 0x0020},
 {0x011A47, 0x0010},
-{0x011A48, 0x0080},
+{0x011A48, 0x0001},
 {0x011A50, 0x0004},
 {0x011A51, 0x0010},
 {0x011A5C, 0x0004},
@@ -1757,117 +1769,117 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x011A9A, 0x0020},
 {0x011A9D, 0x0004},
 {0x011A9E, 0x0020},
-{0x011AA3, 0x0080},
+{0x011AA3, 0x0001},
 {0x011AB0, 0x0004},
-{0x011AF9, 0x0080},
+{0x011AF9, 0x0001},
 {0x011B00, 0x0020},
-{0x011B0A, 0x0080},
+{0x011B0A, 0x0001},
 {0x011C00, 0x0004},
-{0x011C09, 0x0080},
+{0x011C09, 0x0001},
 {0x011C0A, 0x0004},
 {0x011C2F, 0x0010},
-{0x011C37, 0x0080},
+{0x011C37, 0x0001},
 {0x011C38, 0x0010},
 {0x011C40, 0x0004},
 {0x011C41, 0x0020},
-{0x011C46, 0x0080},
+{0x011C46, 0x0001},
 {0x011C50, 0x0002},
-{0x011C6D, 0x0080},
+{0x011C6D, 0x0001},
 {0x011C70, 0x0020},
 {0x011C72, 0x0004},
-{0x011C90, 0x0080},
+{0x011C90, 0x0001},
 {0x011C92, 0x0010},
-{0x011CA8, 0x0080},
+{0x011CA8, 0x0001},
 {0x011CA9, 0x0010},
-{0x011CB7, 0x0080},
+{0x011CB7, 0x0001},
 {0x011D00, 0x0004},
-{0x011D07, 0x0080},
+{0x011D07, 0x0001},
 {0x011D08, 0x0004},
-{0x011D0A, 0x0080},
+{0x011D0A, 0x0001},
 {0x011D0B, 0x0004},
 {0x011D31, 0x0010},
-{0x011D37, 0x0080},
+{0x011D37, 0x0001},
 {0x011D3A, 0x0010},
-{0x011D3B, 0x0080},
+{0x011D3B, 0x0001},
 {0x011D3C, 0x0010},
-{0x011D3E, 0x0080},
+{0x011D3E, 0x0001},
 {0x011D3F, 0x0010},
 {0x011D46, 0x0004},
 {0x011D47, 0x0010},
-{0x011D48, 0x0080},
+{0x011D48, 0x0001},
 {0x011D50, 0x0002},
-{0x011D5A, 0x0080},
+{0x011D5A, 0x0001},
 {0x011D60, 0x0004},
-{0x011D66, 0x0080},
+{0x011D66, 0x0001},
 {0x011D67, 0x0004},
-{0x011D69, 0x0080},
+{0x011D69, 0x0001},
 {0x011D6A, 0x0004},
 {0x011D8A, 0x0010},
-{0x011D8F, 0x0080},
+{0x011D8F, 0x0001},
 {0x011D90, 0x0010},
-{0x011D92, 0x0080},
+{0x011D92, 0x0001},
 {0x011D93, 0x0010},
 {0x011D98, 0x0004},
-{0x011D99, 0x0080},
+{0x011D99, 0x0001},
 {0x011DA0, 0x0002},
-{0x011DAA, 0x0080},
+{0x011DAA, 0x0001},
 {0x011EE0, 0x0004},
 {0x011EF3, 0x0010},
 {0x011EF7, 0x0020},
-{0x011EF9, 0x0080},
+{0x011EF9, 0x0001},
 {0x011F00, 0x0010},
 {0x011F02, 0x0004},
 {0x011F03, 0x0010},
 {0x011F04, 0x0004},
-{0x011F11, 0x0080},
+{0x011F11, 0x0001},
 {0x011F12, 0x0004},
 {0x011F34, 0x0010},
-{0x011F3B, 0x0080},
+{0x011F3B, 0x0001},
 {0x011F3E, 0x0010},
 {0x011F43, 0x0020},
 {0x011F50, 0x0002},
-{0x011F5A, 0x0080},
+{0x011F5A, 0x0001},
 {0x011FB0, 0x0004},
-{0x011FB1, 0x0080},
+{0x011FB1, 0x0001},
 {0x011FC0, 0x0002},
 {0x011FD5, 0x0040},
-{0x011FF2, 0x0080},
+{0x011FF2, 0x0001},
 {0x011FFF, 0x0020},
 {0x012000, 0x0004},
-{0x01239A, 0x0080},
+{0x01239A, 0x0001},
 {0x012400, 0x0002},
-{0x01246F, 0x0080},
+{0x01246F, 0x0001},
 {0x012470, 0x0020},
-{0x012475, 0x0080},
+{0x012475, 0x0001},
 {0x012480, 0x0004},
-{0x012544, 0x0080},
+{0x012544, 0x0001},
 {0x012F90, 0x0004},
 {0x012FF1, 0x0020},
-{0x012FF3, 0x0080},
+{0x012FF3, 0x0001},
 {0x013000, 0x0004},
 {0x013430, 0x0080},
 {0x013440, 0x0010},
 {0x013441, 0x0004},
 {0x013447, 0x0010},
-{0x013456, 0x0080},
+{0x013456, 0x0001},
 {0x014400, 0x0004},
-{0x014647, 0x0080},
+{0x014647, 0x0001},
 {0x016800, 0x0004},
-{0x016A39, 0x0080},
+{0x016A39, 0x0001},
 {0x016A40, 0x0004},
-{0x016A5F, 0x0080},
+{0x016A5F, 0x0001},
 {0x016A60, 0x0002},
-{0x016A6A, 0x0080},
+{0x016A6A, 0x0001},
 {0x016A6E, 0x0020},
 {0x016A70, 0x0004},
-{0x016ABF, 0x0080},
+{0x016ABF, 0x0001},
 {0x016AC0, 0x0002},
-{0x016ACA, 0x0080},
+{0x016ACA, 0x0001},
 {0x016AD0, 0x0004},
-{0x016AEE, 0x0080},
+{0x016AEE, 0x0001},
 {0x016AF0, 0x0010},
 {0x016AF5, 0x0020},
-{0x016AF6, 0x0080},
+{0x016AF6, 0x0001},
 {0x016B00, 0x0004},
 {0x016B30, 0x0010},
 {0x016B37, 0x0020},
@@ -1875,81 +1887,82 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x016B40, 0x0004},
 {0x016B44, 0x0020},
 {0x016B45, 0x0040},
-{0x016B46, 0x0080},
+{0x016B46, 0x0001},
 {0x016B50, 0x0002},
-{0x016B5A, 0x0080},
+{0x016B5A, 0x0001},
 {0x016B5B, 0x0002},
-{0x016B62, 0x0080},
+{0x016B62, 0x0001},
 {0x016B63, 0x0004},
-{0x016B78, 0x0080},
+{0x016B78, 0x0001},
 {0x016B7D, 0x0004},
-{0x016B90, 0x0080},
+{0x016B90, 0x0001},
 {0x016E40, 0x0004},
 {0x016E80, 0x0002},
 {0x016E97, 0x0020},
-{0x016E9B, 0x0080},
+{0x016E9B, 0x0001},
 {0x016F00, 0x0004},
-{0x016F4B, 0x0080},
+{0x016F4B, 0x0001},
 {0x016F4F, 0x0010},
 {0x016F50, 0x0004},
 {0x016F51, 0x0010},
-{0x016F88, 0x0080},
+{0x016F88, 0x0001},
 {0x016F8F, 0x0010},
 {0x016F93, 0x0004},
-{0x016FA0, 0x0080},
+{0x016FA0, 0x0001},
 {0x016FE0, 0x0004},
 {0x016FE2, 0x0020},
 {0x016FE3, 0x0004},
 {0x016FE4, 0x0010},
-{0x016FE5, 0x0080},
+{0x016FE5, 0x0001},
 {0x016FF0, 0x0010},
-{0x016FF2, 0x0080},
+{0x016FF2, 0x0001},
 {0x017000, 0x0004},
-{0x0187F8, 0x0080},
+{0x0187F8, 0x0001},
 {0x018800, 0x0004},
-{0x018CD6, 0x0080},
+{0x018CD6, 0x0001},
 {0x018D00, 0x0004},
-{0x018D09, 0x0080},
+{0x018D09, 0x0001},
 {0x01AFF0, 0x0004},
-{0x01AFF4, 0x0080},
+{0x01AFF4, 0x0001},
 {0x01AFF5, 0x0004},
-{0x01AFFC, 0x0080},
+{0x01AFFC, 0x0001},
 {0x01AFFD, 0x0004},
-{0x01AFFF, 0x0080},
+{0x01AFFF, 0x0001},
 {0x01B000, 0x0004},
-{0x01B123, 0x0080},
+{0x01B123, 0x0001},
 {0x01B132, 0x0004},
-{0x01B133, 0x0080},
+{0x01B133, 0x0001},
 {0x01B150, 0x0004},
-{0x01B153, 0x0080},
+{0x01B153, 0x0001},
 {0x01B155, 0x0004},
-{0x01B156, 0x0080},
+{0x01B156, 0x0001},
 {0x01B164, 0x0004},
-{0x01B168, 0x0080},
+{0x01B168, 0x0001},
 {0x01B170, 0x0004},
-{0x01B2FC, 0x0080},
+{0x01B2FC, 0x0001},
 {0x01BC00, 0x0004},
-{0x01BC6B, 0x0080},
+{0x01BC6B, 0x0001},
 {0x01BC70, 0x0004},
-{0x01BC7D, 0x0080},
+{0x01BC7D, 0x0001},
 {0x01BC80, 0x0004},
-{0x01BC89, 0x0080},
+{0x01BC89, 0x0001},
 {0x01BC90, 0x0004},
-{0x01BC9A, 0x0080},
+{0x01BC9A, 0x0001},
 {0x01BC9C, 0x0040},
 {0x01BC9D, 0x0010},
 {0x01BC9F, 0x0020},
 {0x01BCA0, 0x0080},
+{0x01BCA4, 0x0001},
 {0x01CF00, 0x0010},
-{0x01CF2E, 0x0080},
+{0x01CF2E, 0x0001},
 {0x01CF30, 0x0010},
-{0x01CF47, 0x0080},
+{0x01CF47, 0x0001},
 {0x01CF50, 0x0040},
-{0x01CFC4, 0x0080},
+{0x01CFC4, 0x0001},
 {0x01D000, 0x0040},
-{0x01D0F6, 0x0080},
+{0x01D0F6, 0x0001},
 {0x01D100, 0x0040},
-{0x01D127, 0x0080},
+{0x01D127, 0x0001},
 {0x01D129, 0x0040},
 {0x01D165, 0x0010},
 {0x01D16A, 0x0040},
@@ -1961,57 +1974,57 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x01D18C, 0x0040},
 {0x01D1AA, 0x0010},
 {0x01D1AE, 0x0040},
-{0x01D1EB, 0x0080},
+{0x01D1EB, 0x0001},
 {0x01D200, 0x0040},
 {0x01D242, 0x0010},
 {0x01D245, 0x0040},
-{0x01D246, 0x0080},
+{0x01D246, 0x0001},
 {0x01D2C0, 0x0002},
-{0x01D2D4, 0x0080},
+{0x01D2D4, 0x0001},
 {0x01D2E0, 0x0002},
-{0x01D2F4, 0x0080},
+{0x01D2F4, 0x0001},
 {0x01D300, 0x0040},
-{0x01D357, 0x0080},
+{0x01D357, 0x0001},
 {0x01D360, 0x0002},
-{0x01D379, 0x0080},
+{0x01D379, 0x0001},
 {0x01D400, 0x0004},
-{0x01D455, 0x0080},
+{0x01D455, 0x0001},
 {0x01D456, 0x0004},
-{0x01D49D, 0x0080},
+{0x01D49D, 0x0001},
 {0x01D49E, 0x0004},
-{0x01D4A0, 0x0080},
+{0x01D4A0, 0x0001},
 {0x01D4A2, 0x0004},
-{0x01D4A3, 0x0080},
+{0x01D4A3, 0x0001},
 {0x01D4A5, 0x0004},
-{0x01D4A7, 0x0080},
+{0x01D4A7, 0x0001},
 {0x01D4A9, 0x0004},
-{0x01D4AD, 0x0080},
+{0x01D4AD, 0x0001},
 {0x01D4AE, 0x0004},
-{0x01D4BA, 0x0080},
+{0x01D4BA, 0x0001},
 {0x01D4BB, 0x0004},
-{0x01D4BC, 0x0080},
+{0x01D4BC, 0x0001},
 {0x01D4BD, 0x0004},
-{0x01D4C4, 0x0080},
+{0x01D4C4, 0x0001},
 {0x01D4C5, 0x0004},
-{0x01D506, 0x0080},
+{0x01D506, 0x0001},
 {0x01D507, 0x0004},
-{0x01D50B, 0x0080},
+{0x01D50B, 0x0001},
 {0x01D50D, 0x0004},
-{0x01D515, 0x0080},
+{0x01D515, 0x0001},
 {0x01D516, 0x0004},
-{0x01D51D, 0x0080},
+{0x01D51D, 0x0001},
 {0x01D51E, 0x0004},
-{0x01D53A, 0x0080},
+{0x01D53A, 0x0001},
 {0x01D53B, 0x0004},
-{0x01D53F, 0x0080},
+{0x01D53F, 0x0001},
 {0x01D540, 0x0004},
-{0x01D545, 0x0080},
+{0x01D545, 0x0001},
 {0x01D546, 0x0004},
-{0x01D547, 0x0080},
+{0x01D547, 0x0001},
 {0x01D54A, 0x0004},
-{0x01D551, 0x0080},
+{0x01D551, 0x0001},
 {0x01D552, 0x0004},
-{0x01D6A6, 0x0080},
+{0x01D6A6, 0x0001},
 {0x01D6A8, 0x0004},
 {0x01D6C1, 0x0040},
 {0x01D6C2, 0x0004},
@@ -2033,7 +2046,7 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x01D7AA, 0x0004},
 {0x01D7C3, 0x0040},
 {0x01D7C4, 0x0004},
-{0x01D7CC, 0x0080},
+{0x01D7CC, 0x0001},
 {0x01D7CE, 0x0002},
 {0x01D800, 0x0040},
 {0x01DA00, 0x0010},
@@ -2045,251 +2058,283 @@ const std::vector> unicode_ranges_flags = {  // st
 {0x01DA84, 0x0010},
 {0x01DA85, 0x0040},
 {0x01DA87, 0x0020},
-{0x01DA8C, 0x0080},
+{0x01DA8C, 0x0001},
 {0x01DA9B, 0x0010},
-{0x01DAA0, 0x0080},
+{0x01DAA0, 0x0001},
 {0x01DAA1, 0x0010},
-{0x01DAB0, 0x0080},
+{0x01DAB0, 0x0001},
 {0x01DF00, 0x0004},
-{0x01DF1F, 0x0080},
+{0x01DF1F, 0x0001},
 {0x01DF25, 0x0004},
-{0x01DF2B, 0x0080},
+{0x01DF2B, 0x0001},
 {0x01E000, 0x0010},
-{0x01E007, 0x0080},
+{0x01E007, 0x0001},
 {0x01E008, 0x0010},
-{0x01E019, 0x0080},
+{0x01E019, 0x0001},
 {0x01E01B, 0x0010},
-{0x01E022, 0x0080},
+{0x01E022, 0x0001},
 {0x01E023, 0x0010},
-{0x01E025, 0x0080},
+{0x01E025, 0x0001},
 {0x01E026, 0x0010},
-{0x01E02B, 0x0080},
+{0x01E02B, 0x0001},
 {0x01E030, 0x0004},
-{0x01E06E, 0x0080},
+{0x01E06E, 0x0001},
 {0x01E08F, 0x0010},
-{0x01E090, 0x0080},
+{0x01E090, 0x0001},
 {0x01E100, 0x0004},
-{0x01E12D, 0x0080},
+{0x01E12D, 0x0001},
 {0x01E130, 0x0010},
 {0x01E137, 0x0004},
-{0x01E13E, 0x0080},
+{0x01E13E, 0x0001},
 {0x01E140, 0x0002},
-{0x01E14A, 0x0080},
+{0x01E14A, 0x0001},
 {0x01E14E, 0x0004},
 {0x01E14F, 0x0040},
-{0x01E150, 0x0080},
+{0x01E150, 0x0001},
 {0x01E290, 0x0004},
 {0x01E2AE, 0x0010},
-{0x01E2AF, 0x0080},
+{0x01E2AF, 0x0001},
 {0x01E2C0, 0x0004},
 {0x01E2EC, 0x0010},
 {0x01E2F0, 0x0002},
-{0x01E2FA, 0x0080},
+{0x01E2FA, 0x0001},
 {0x01E2FF, 0x0040},
-{0x01E300, 0x0080},
+{0x01E300, 0x0001},
 {0x01E4D0, 0x0004},
 {0x01E4EC, 0x0010},
 {0x01E4F0, 0x0002},
-{0x01E4FA, 0x0080},
+{0x01E4FA, 0x0001},
 {0x01E7E0, 0x0004},
-{0x01E7E7, 0x0080},
+{0x01E7E7, 0x0001},
 {0x01E7E8, 0x0004},
-{0x01E7EC, 0x0080},
+{0x01E7EC, 0x0001},
 {0x01E7ED, 0x0004},
-{0x01E7EF, 0x0080},
+{0x01E7EF, 0x0001},
 {0x01E7F0, 0x0004},
-{0x01E7FF, 0x0080},
+{0x01E7FF, 0x0001},
 {0x01E800, 0x0004},
-{0x01E8C5, 0x0080},
+{0x01E8C5, 0x0001},
 {0x01E8C7, 0x0002},
 {0x01E8D0, 0x0010},
-{0x01E8D7, 0x0080},
+{0x01E8D7, 0x0001},
 {0x01E900, 0x0004},
 {0x01E944, 0x0010},
 {0x01E94B, 0x0004},
-{0x01E94C, 0x0080},
+{0x01E94C, 0x0001},
 {0x01E950, 0x0002},
-{0x01E95A, 0x0080},
+{0x01E95A, 0x0001},
 {0x01E95E, 0x0020},
-{0x01E960, 0x0080},
+{0x01E960, 0x0001},
 {0x01EC71, 0x0002},
 {0x01ECAC, 0x0040},
 {0x01ECAD, 0x0002},
 {0x01ECB0, 0x0040},
 {0x01ECB1, 0x0002},
-{0x01ECB5, 0x0080},
+{0x01ECB5, 0x0001},
 {0x01ED01, 0x0002},
 {0x01ED2E, 0x0040},
 {0x01ED2F, 0x0002},
-{0x01ED3E, 0x0080},
+{0x01ED3E, 0x0001},
 {0x01EE00, 0x0004},
-{0x01EE04, 0x0080},
+{0x01EE04, 0x0001},
 {0x01EE05, 0x0004},
-{0x01EE20, 0x0080},
+{0x01EE20, 0x0001},
 {0x01EE21, 0x0004},
-{0x01EE23, 0x0080},
+{0x01EE23, 0x0001},
 {0x01EE24, 0x0004},
-{0x01EE25, 0x0080},
+{0x01EE25, 0x0001},
 {0x01EE27, 0x0004},
-{0x01EE28, 0x0080},
+{0x01EE28, 0x0001},
 {0x01EE29, 0x0004},
-{0x01EE33, 0x0080},
+{0x01EE33, 0x0001},
 {0x01EE34, 0x0004},
-{0x01EE38, 0x0080},
+{0x01EE38, 0x0001},
 {0x01EE39, 0x0004},
-{0x01EE3A, 0x0080},
+{0x01EE3A, 0x0001},
 {0x01EE3B, 0x0004},
-{0x01EE3C, 0x0080},
+{0x01EE3C, 0x0001},
 {0x01EE42, 0x0004},
-{0x01EE43, 0x0080},
+{0x01EE43, 0x0001},
 {0x01EE47, 0x0004},
-{0x01EE48, 0x0080},
+{0x01EE48, 0x0001},
 {0x01EE49, 0x0004},
-{0x01EE4A, 0x0080},
+{0x01EE4A, 0x0001},
 {0x01EE4B, 0x0004},
-{0x01EE4C, 0x0080},
+{0x01EE4C, 0x0001},
 {0x01EE4D, 0x0004},
-{0x01EE50, 0x0080},
+{0x01EE50, 0x0001},
 {0x01EE51, 0x0004},
-{0x01EE53, 0x0080},
+{0x01EE53, 0x0001},
 {0x01EE54, 0x0004},
-{0x01EE55, 0x0080},
+{0x01EE55, 0x0001},
 {0x01EE57, 0x0004},
-{0x01EE58, 0x0080},
+{0x01EE58, 0x0001},
 {0x01EE59, 0x0004},
-{0x01EE5A, 0x0080},
+{0x01EE5A, 0x0001},
 {0x01EE5B, 0x0004},
-{0x01EE5C, 0x0080},
+{0x01EE5C, 0x0001},
 {0x01EE5D, 0x0004},
-{0x01EE5E, 0x0080},
+{0x01EE5E, 0x0001},
 {0x01EE5F, 0x0004},
-{0x01EE60, 0x0080},
+{0x01EE60, 0x0001},
 {0x01EE61, 0x0004},
-{0x01EE63, 0x0080},
+{0x01EE63, 0x0001},
 {0x01EE64, 0x0004},
-{0x01EE65, 0x0080},
+{0x01EE65, 0x0001},
 {0x01EE67, 0x0004},
-{0x01EE6B, 0x0080},
+{0x01EE6B, 0x0001},
 {0x01EE6C, 0x0004},
-{0x01EE73, 0x0080},
+{0x01EE73, 0x0001},
 {0x01EE74, 0x0004},
-{0x01EE78, 0x0080},
+{0x01EE78, 0x0001},
 {0x01EE79, 0x0004},
-{0x01EE7D, 0x0080},
+{0x01EE7D, 0x0001},
 {0x01EE7E, 0x0004},
-{0x01EE7F, 0x0080},
+{0x01EE7F, 0x0001},
 {0x01EE80, 0x0004},
-{0x01EE8A, 0x0080},
+{0x01EE8A, 0x0001},
 {0x01EE8B, 0x0004},
-{0x01EE9C, 0x0080},
+{0x01EE9C, 0x0001},
 {0x01EEA1, 0x0004},
-{0x01EEA4, 0x0080},
+{0x01EEA4, 0x0001},
 {0x01EEA5, 0x0004},
-{0x01EEAA, 0x0080},
+{0x01EEAA, 0x0001},
 {0x01EEAB, 0x0004},
-{0x01EEBC, 0x0080},
+{0x01EEBC, 0x0001},
 {0x01EEF0, 0x0040},
-{0x01EEF2, 0x0080},
+{0x01EEF2, 0x0001},
 {0x01F000, 0x0040},
-{0x01F02C, 0x0080},
+{0x01F02C, 0x0001},
 {0x01F030, 0x0040},
-{0x01F094, 0x0080},
+{0x01F094, 0x0001},
 {0x01F0A0, 0x0040},
-{0x01F0AF, 0x0080},
+{0x01F0AF, 0x0001},
 {0x01F0B1, 0x0040},
-{0x01F0C0, 0x0080},
+{0x01F0C0, 0x0001},
 {0x01F0C1, 0x0040},
-{0x01F0D0, 0x0080},
+{0x01F0D0, 0x0001},
 {0x01F0D1, 0x0040},
-{0x01F0F6, 0x0080},
+{0x01F0F6, 0x0001},
 {0x01F100, 0x0002},
 {0x01F10D, 0x0040},
-{0x01F1AE, 0x0080},
+{0x01F1AE, 0x0001},
 {0x01F1E6, 0x0040},
-{0x01F203, 0x0080},
+{0x01F203, 0x0001},
 {0x01F210, 0x0040},
-{0x01F23C, 0x0080},
+{0x01F23C, 0x0001},
 {0x01F240, 0x0040},
-{0x01F249, 0x0080},
+{0x01F249, 0x0001},
 {0x01F250, 0x0040},
-{0x01F252, 0x0080},
+{0x01F252, 0x0001},
 {0x01F260, 0x0040},
-{0x01F266, 0x0080},
+{0x01F266, 0x0001},
 {0x01F300, 0x0040},
-{0x01F6D8, 0x0080},
+{0x01F6D8, 0x0001},
 {0x01F6DC, 0x0040},
-{0x01F6ED, 0x0080},
+{0x01F6ED, 0x0001},
 {0x01F6F0, 0x0040},
-{0x01F6FD, 0x0080},
+{0x01F6FD, 0x0001},
 {0x01F700, 0x0040},
-{0x01F777, 0x0080},
+{0x01F777, 0x0001},
 {0x01F77B, 0x0040},
-{0x01F7DA, 0x0080},
+{0x01F7DA, 0x0001},
 {0x01F7E0, 0x0040},
-{0x01F7EC, 0x0080},
+{0x01F7EC, 0x0001},
 {0x01F7F0, 0x0040},
-{0x01F7F1, 0x0080},
+{0x01F7F1, 0x0001},
 {0x01F800, 0x0040},
-{0x01F80C, 0x0080},
+{0x01F80C, 0x0001},
 {0x01F810, 0x0040},
-{0x01F848, 0x0080},
+{0x01F848, 0x0001},
 {0x01F850, 0x0040},
-{0x01F85A, 0x0080},
+{0x01F85A, 0x0001},
 {0x01F860, 0x0040},
-{0x01F888, 0x0080},
+{0x01F888, 0x0001},
 {0x01F890, 0x0040},
-{0x01F8AE, 0x0080},
+{0x01F8AE, 0x0001},
 {0x01F8B0, 0x0040},
-{0x01F8B2, 0x0080},
+{0x01F8B2, 0x0001},
 {0x01F900, 0x0040},
-{0x01FA54, 0x0080},
+{0x01FA54, 0x0001},
 {0x01FA60, 0x0040},
-{0x01FA6E, 0x0080},
+{0x01FA6E, 0x0001},
 {0x01FA70, 0x0040},
-{0x01FA7D, 0x0080},
+{0x01FA7D, 0x0001},
 {0x01FA80, 0x0040},
-{0x01FA89, 0x0080},
+{0x01FA89, 0x0001},
 {0x01FA90, 0x0040},
-{0x01FABE, 0x0080},
+{0x01FABE, 0x0001},
 {0x01FABF, 0x0040},
-{0x01FAC6, 0x0080},
+{0x01FAC6, 0x0001},
 {0x01FACE, 0x0040},
-{0x01FADC, 0x0080},
+{0x01FADC, 0x0001},
 {0x01FAE0, 0x0040},
-{0x01FAE9, 0x0080},
+{0x01FAE9, 0x0001},
 {0x01FAF0, 0x0040},
-{0x01FAF9, 0x0080},
+{0x01FAF9, 0x0001},
 {0x01FB00, 0x0040},
-{0x01FB93, 0x0080},
+{0x01FB93, 0x0001},
 {0x01FB94, 0x0040},
-{0x01FBCB, 0x0080},
+{0x01FBCB, 0x0001},
 {0x01FBF0, 0x0002},
-{0x01FBFA, 0x0080},
+{0x01FBFA, 0x0001},
 {0x020000, 0x0004},
-{0x02A6E0, 0x0080},
+{0x02A6E0, 0x0001},
 {0x02A700, 0x0004},
-{0x02B73A, 0x0080},
+{0x02B73A, 0x0001},
 {0x02B740, 0x0004},
-{0x02B81E, 0x0080},
+{0x02B81E, 0x0001},
 {0x02B820, 0x0004},
-{0x02CEA2, 0x0080},
+{0x02CEA2, 0x0001},
 {0x02CEB0, 0x0004},
-{0x02EBE1, 0x0080},
+{0x02EBE1, 0x0001},
 {0x02EBF0, 0x0004},
-{0x02EE5E, 0x0080},
+{0x02EE5E, 0x0001},
 {0x02F800, 0x0004},
-{0x02FA1E, 0x0080},
+{0x02FA1E, 0x0001},
 {0x030000, 0x0004},
-{0x03134B, 0x0080},
+{0x03134B, 0x0001},
 {0x031350, 0x0004},
-{0x0323B0, 0x0080},
+{0x0323B0, 0x0001},
+{0x0E0001, 0x0080},
+{0x0E0002, 0x0001},
+{0x0E0020, 0x0080},
+{0x0E0080, 0x0001},
 {0x0E0100, 0x0010},
-{0x0E01F0, 0x0080},
+{0x0E01F0, 0x0001},
+{0x0F0000, 0x0080},
+{0x0FFFFE, 0x0001},
+{0x100000, 0x0080},
+{0x10FFFE, 0x0001},
 {0x110000, 0x0000},
 };
 
 const std::unordered_set unicode_set_whitespace = {
-0x000009, 0x00000A, 0x00000B, 0x00000C, 0x00000D, 0x000020, 0x000085, 0x0000A0, 0x001680, 0x002000, 0x002001, 0x002002, 0x002003, 0x002004, 0x002005, 0x002006, 0x002007, 0x002008, 0x002009, 0x00200A, 0x002028, 0x002029, 0x00202F, 0x00205F, 0x003000
+0x000009,
+0x00000A,
+0x00000B,
+0x00000C,
+0x00000D,
+0x000020,
+0x000085,
+0x0000A0,
+0x001680,
+0x002000,
+0x002001,
+0x002002,
+0x002003,
+0x002004,
+0x002005,
+0x002006,
+0x002007,
+0x002008,
+0x002009,
+0x00200A,
+0x002028,
+0x002029,
+0x00202F,
+0x00205F,
+0x003000,
 };
 
 const std::unordered_map unicode_map_lowercase = {
@@ -3248,6 +3293,7 @@ const std::unordered_map unicode_map_lowercase = {
 {0x002C2C, 0x002C5C},
 {0x002C2D, 0x002C5D},
 {0x002C2E, 0x002C5E},
+{0x002C2F, 0x002C5F},
 {0x002C60, 0x002C61},
 {0x002C62, 0x00026B},
 {0x002C63, 0x001D7D},
@@ -3428,12 +3474,16 @@ const std::unordered_map unicode_map_lowercase = {
 {0x00A7BA, 0x00A7BB},
 {0x00A7BC, 0x00A7BD},
 {0x00A7BE, 0x00A7BF},
+{0x00A7C0, 0x00A7C1},
 {0x00A7C2, 0x00A7C3},
 {0x00A7C4, 0x00A794},
 {0x00A7C5, 0x000282},
 {0x00A7C6, 0x001D8E},
 {0x00A7C7, 0x00A7C8},
 {0x00A7C9, 0x00A7CA},
+{0x00A7D0, 0x00A7D1},
+{0x00A7D6, 0x00A7D7},
+{0x00A7D8, 0x00A7D9},
 {0x00A7F5, 0x00A7F6},
 {0x00FF21, 0x00FF41},
 {0x00FF22, 0x00FF42},
@@ -3537,6 +3587,41 @@ const std::unordered_map unicode_map_lowercase = {
 {0x0104D1, 0x0104F9},
 {0x0104D2, 0x0104FA},
 {0x0104D3, 0x0104FB},
+{0x010570, 0x010597},
+{0x010571, 0x010598},
+{0x010572, 0x010599},
+{0x010573, 0x01059A},
+{0x010574, 0x01059B},
+{0x010575, 0x01059C},
+{0x010576, 0x01059D},
+{0x010577, 0x01059E},
+{0x010578, 0x01059F},
+{0x010579, 0x0105A0},
+{0x01057A, 0x0105A1},
+{0x01057C, 0x0105A3},
+{0x01057D, 0x0105A4},
+{0x01057E, 0x0105A5},
+{0x01057F, 0x0105A6},
+{0x010580, 0x0105A7},
+{0x010581, 0x0105A8},
+{0x010582, 0x0105A9},
+{0x010583, 0x0105AA},
+{0x010584, 0x0105AB},
+{0x010585, 0x0105AC},
+{0x010586, 0x0105AD},
+{0x010587, 0x0105AE},
+{0x010588, 0x0105AF},
+{0x010589, 0x0105B0},
+{0x01058A, 0x0105B1},
+{0x01058C, 0x0105B3},
+{0x01058D, 0x0105B4},
+{0x01058E, 0x0105B5},
+{0x01058F, 0x0105B6},
+{0x010590, 0x0105B7},
+{0x010591, 0x0105B8},
+{0x010592, 0x0105B9},
+{0x010594, 0x0105BB},
+{0x010595, 0x0105BC},
 {0x010C80, 0x010CC0},
 {0x010C81, 0x010CC1},
 {0x010C82, 0x010CC2},
@@ -3716,7 +3801,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x000079, 0x000059},
 {0x00007A, 0x00005A},
 {0x0000B5, 0x00039C},
-{0x0000DF, 0x000053},
 {0x0000E0, 0x0000C0},
 {0x0000E1, 0x0000C1},
 {0x0000E2, 0x0000C2},
@@ -3784,7 +3868,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x000144, 0x000143},
 {0x000146, 0x000145},
 {0x000148, 0x000147},
-{0x000149, 0x0002BC},
 {0x00014B, 0x00014A},
 {0x00014D, 0x00014C},
 {0x00014F, 0x00014E},
@@ -3857,7 +3940,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x0001EB, 0x0001EA},
 {0x0001ED, 0x0001EC},
 {0x0001EF, 0x0001EE},
-{0x0001F0, 0x00004A},
 {0x0001F2, 0x0001F1},
 {0x0001F3, 0x0001F1},
 {0x0001F5, 0x0001F4},
@@ -3943,12 +4025,10 @@ const std::unordered_map unicode_map_uppercase = {
 {0x00037B, 0x0003FD},
 {0x00037C, 0x0003FE},
 {0x00037D, 0x0003FF},
-{0x000390, 0x000399},
 {0x0003AC, 0x000386},
 {0x0003AD, 0x000388},
 {0x0003AE, 0x000389},
 {0x0003AF, 0x00038A},
-{0x0003B0, 0x0003A5},
 {0x0003B1, 0x000391},
 {0x0003B2, 0x000392},
 {0x0003B3, 0x000393},
@@ -4189,7 +4269,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x000584, 0x000554},
 {0x000585, 0x000555},
 {0x000586, 0x000556},
-{0x000587, 0x000535},
 {0x0010D0, 0x001C90},
 {0x0010D1, 0x001C91},
 {0x0010D2, 0x001C92},
@@ -4329,11 +4408,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x001E91, 0x001E90},
 {0x001E93, 0x001E92},
 {0x001E95, 0x001E94},
-{0x001E96, 0x000048},
-{0x001E97, 0x000054},
-{0x001E98, 0x000057},
-{0x001E99, 0x000059},
-{0x001E9A, 0x000041},
 {0x001E9B, 0x001E60},
 {0x001EA1, 0x001EA0},
 {0x001EA3, 0x001EA2},
@@ -4419,13 +4493,9 @@ const std::unordered_map unicode_map_uppercase = {
 {0x001F43, 0x001F4B},
 {0x001F44, 0x001F4C},
 {0x001F45, 0x001F4D},
-{0x001F50, 0x0003A5},
 {0x001F51, 0x001F59},
-{0x001F52, 0x0003A5},
 {0x001F53, 0x001F5B},
-{0x001F54, 0x0003A5},
 {0x001F55, 0x001F5D},
-{0x001F56, 0x0003A5},
 {0x001F57, 0x001F5F},
 {0x001F60, 0x001F68},
 {0x001F61, 0x001F69},
@@ -4449,89 +4519,41 @@ const std::unordered_map unicode_map_uppercase = {
 {0x001F7B, 0x001FEB},
 {0x001F7C, 0x001FFA},
 {0x001F7D, 0x001FFB},
-{0x001F80, 0x001F08},
-{0x001F81, 0x001F09},
-{0x001F82, 0x001F0A},
-{0x001F83, 0x001F0B},
-{0x001F84, 0x001F0C},
-{0x001F85, 0x001F0D},
-{0x001F86, 0x001F0E},
-{0x001F87, 0x001F0F},
-{0x001F88, 0x001F08},
-{0x001F89, 0x001F09},
-{0x001F8A, 0x001F0A},
-{0x001F8B, 0x001F0B},
-{0x001F8C, 0x001F0C},
-{0x001F8D, 0x001F0D},
-{0x001F8E, 0x001F0E},
-{0x001F8F, 0x001F0F},
-{0x001F90, 0x001F28},
-{0x001F91, 0x001F29},
-{0x001F92, 0x001F2A},
-{0x001F93, 0x001F2B},
-{0x001F94, 0x001F2C},
-{0x001F95, 0x001F2D},
-{0x001F96, 0x001F2E},
-{0x001F97, 0x001F2F},
-{0x001F98, 0x001F28},
-{0x001F99, 0x001F29},
-{0x001F9A, 0x001F2A},
-{0x001F9B, 0x001F2B},
-{0x001F9C, 0x001F2C},
-{0x001F9D, 0x001F2D},
-{0x001F9E, 0x001F2E},
-{0x001F9F, 0x001F2F},
-{0x001FA0, 0x001F68},
-{0x001FA1, 0x001F69},
-{0x001FA2, 0x001F6A},
-{0x001FA3, 0x001F6B},
-{0x001FA4, 0x001F6C},
-{0x001FA5, 0x001F6D},
-{0x001FA6, 0x001F6E},
-{0x001FA7, 0x001F6F},
-{0x001FA8, 0x001F68},
-{0x001FA9, 0x001F69},
-{0x001FAA, 0x001F6A},
-{0x001FAB, 0x001F6B},
-{0x001FAC, 0x001F6C},
-{0x001FAD, 0x001F6D},
-{0x001FAE, 0x001F6E},
-{0x001FAF, 0x001F6F},
+{0x001F80, 0x001F88},
+{0x001F81, 0x001F89},
+{0x001F82, 0x001F8A},
+{0x001F83, 0x001F8B},
+{0x001F84, 0x001F8C},
+{0x001F85, 0x001F8D},
+{0x001F86, 0x001F8E},
+{0x001F87, 0x001F8F},
+{0x001F90, 0x001F98},
+{0x001F91, 0x001F99},
+{0x001F92, 0x001F9A},
+{0x001F93, 0x001F9B},
+{0x001F94, 0x001F9C},
+{0x001F95, 0x001F9D},
+{0x001F96, 0x001F9E},
+{0x001F97, 0x001F9F},
+{0x001FA0, 0x001FA8},
+{0x001FA1, 0x001FA9},
+{0x001FA2, 0x001FAA},
+{0x001FA3, 0x001FAB},
+{0x001FA4, 0x001FAC},
+{0x001FA5, 0x001FAD},
+{0x001FA6, 0x001FAE},
+{0x001FA7, 0x001FAF},
 {0x001FB0, 0x001FB8},
 {0x001FB1, 0x001FB9},
-{0x001FB2, 0x001FBA},
-{0x001FB3, 0x000391},
-{0x001FB4, 0x000386},
-{0x001FB6, 0x000391},
-{0x001FB7, 0x000391},
-{0x001FBC, 0x000391},
+{0x001FB3, 0x001FBC},
 {0x001FBE, 0x000399},
-{0x001FC2, 0x001FCA},
-{0x001FC3, 0x000397},
-{0x001FC4, 0x000389},
-{0x001FC6, 0x000397},
-{0x001FC7, 0x000397},
-{0x001FCC, 0x000397},
+{0x001FC3, 0x001FCC},
 {0x001FD0, 0x001FD8},
 {0x001FD1, 0x001FD9},
-{0x001FD2, 0x000399},
-{0x001FD3, 0x000399},
-{0x001FD6, 0x000399},
-{0x001FD7, 0x000399},
 {0x001FE0, 0x001FE8},
 {0x001FE1, 0x001FE9},
-{0x001FE2, 0x0003A5},
-{0x001FE3, 0x0003A5},
-{0x001FE4, 0x0003A1},
 {0x001FE5, 0x001FEC},
-{0x001FE6, 0x0003A5},
-{0x001FE7, 0x0003A5},
-{0x001FF2, 0x001FFA},
-{0x001FF3, 0x0003A9},
-{0x001FF4, 0x00038F},
-{0x001FF6, 0x0003A9},
-{0x001FF7, 0x0003A9},
-{0x001FFC, 0x0003A9},
+{0x001FF3, 0x001FFC},
 {0x00214E, 0x002132},
 {0x002170, 0x002160},
 {0x002171, 0x002161},
@@ -4623,6 +4645,7 @@ const std::unordered_map unicode_map_uppercase = {
 {0x002C5C, 0x002C2C},
 {0x002C5D, 0x002C2D},
 {0x002C5E, 0x002C2E},
+{0x002C5F, 0x002C2F},
 {0x002C61, 0x002C60},
 {0x002C65, 0x00023A},
 {0x002C66, 0x00023E},
@@ -4826,9 +4849,13 @@ const std::unordered_map unicode_map_uppercase = {
 {0x00A7BB, 0x00A7BA},
 {0x00A7BD, 0x00A7BC},
 {0x00A7BF, 0x00A7BE},
+{0x00A7C1, 0x00A7C0},
 {0x00A7C3, 0x00A7C2},
 {0x00A7C8, 0x00A7C7},
 {0x00A7CA, 0x00A7C9},
+{0x00A7D1, 0x00A7D0},
+{0x00A7D7, 0x00A7D6},
+{0x00A7D9, 0x00A7D8},
 {0x00A7F6, 0x00A7F5},
 {0x00AB53, 0x00A7B3},
 {0x00AB70, 0x0013A0},
@@ -4911,18 +4938,6 @@ const std::unordered_map unicode_map_uppercase = {
 {0x00ABBD, 0x0013ED},
 {0x00ABBE, 0x0013EE},
 {0x00ABBF, 0x0013EF},
-{0x00FB00, 0x000046},
-{0x00FB01, 0x000046},
-{0x00FB02, 0x000046},
-{0x00FB03, 0x000046},
-{0x00FB04, 0x000046},
-{0x00FB05, 0x000053},
-{0x00FB06, 0x000053},
-{0x00FB13, 0x000544},
-{0x00FB14, 0x000544},
-{0x00FB15, 0x000544},
-{0x00FB16, 0x00054E},
-{0x00FB17, 0x000544},
 {0x00FF41, 0x00FF21},
 {0x00FF42, 0x00FF22},
 {0x00FF43, 0x00FF23},
@@ -5025,6 +5040,41 @@ const std::unordered_map unicode_map_uppercase = {
 {0x0104F9, 0x0104D1},
 {0x0104FA, 0x0104D2},
 {0x0104FB, 0x0104D3},
+{0x010597, 0x010570},
+{0x010598, 0x010571},
+{0x010599, 0x010572},
+{0x01059A, 0x010573},
+{0x01059B, 0x010574},
+{0x01059C, 0x010575},
+{0x01059D, 0x010576},
+{0x01059E, 0x010577},
+{0x01059F, 0x010578},
+{0x0105A0, 0x010579},
+{0x0105A1, 0x01057A},
+{0x0105A3, 0x01057C},
+{0x0105A4, 0x01057D},
+{0x0105A5, 0x01057E},
+{0x0105A6, 0x01057F},
+{0x0105A7, 0x010580},
+{0x0105A8, 0x010581},
+{0x0105A9, 0x010582},
+{0x0105AA, 0x010583},
+{0x0105AB, 0x010584},
+{0x0105AC, 0x010585},
+{0x0105AD, 0x010586},
+{0x0105AE, 0x010587},
+{0x0105AF, 0x010588},
+{0x0105B0, 0x010589},
+{0x0105B1, 0x01058A},
+{0x0105B3, 0x01058C},
+{0x0105B4, 0x01058D},
+{0x0105B5, 0x01058E},
+{0x0105B6, 0x01058F},
+{0x0105B7, 0x010590},
+{0x0105B8, 0x010591},
+{0x0105B9, 0x010592},
+{0x0105BB, 0x010594},
+{0x0105BC, 0x010595},
 {0x010CC0, 0x010C80},
 {0x010CC1, 0x010C81},
 {0x010CC2, 0x010C82},
@@ -7006,4 +7056,3 @@ const std::vector unicode_ranges_nfd = {  // start, last, nfd
 {0x02FA1C, 0x02FA1C, 0x009F3B},
 {0x02FA1D, 0x02FA1D, 0x02A600},
 };
-
diff --git a/llama/unicode-data.h b/llama/unicode-data.h
index 458e5bf1..3abb9c74 100644
--- a/llama/unicode-data.h
+++ b/llama/unicode-data.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
diff --git a/llama/unicode.cpp b/llama/unicode.cpp
index 4e0ff2aa..774a5210 100644
--- a/llama/unicode.cpp
+++ b/llama/unicode.cpp
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -24,6 +24,10 @@
  * SOFTWARE.
  */
 
+#if defined(_MSC_VER)
+#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
+#endif
+
 #include "unicode.h"
 #include "unicode-data.h"
 
@@ -41,6 +45,12 @@
 #include 
 #include 
 
+size_t unicode_len_utf8(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast(src) >> 4;
+    return lookup[highbits];
+}
+
 static std::string unicode_cpts_to_utf8(const std::vector & cps) {
     std::string result;
     for (size_t i = 0; i < cps.size(); ++i) {
@@ -49,7 +59,7 @@ static std::string unicode_cpts_to_utf8(const std::vector & cps) {
     return result;
 }
 
-static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
     assert(offset < utf8.size());
     if (!(utf8[offset + 0] & 0x80)) {
         auto result = utf8[offset + 0];
@@ -252,13 +262,13 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
         assert(offset_end <= cpts.size());
         start = offset_end;
 
-        auto _get_cpt = [&] (const size_t pos) -> char32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -279,18 +289,18 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
         };
 
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
-            const char32_t cpt = _get_cpt(pos);
+            const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
             // regex: 's|'t|'re|'ve|'m|'ll|'d
             if (cpt == '\'' && pos+1 < offset_end) {
-                char32_t cpt_next = _get_cpt(pos+1);
+                uint32_t cpt_next = _get_cpt(pos+1);
                 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                     pos += _add_token(pos+2);
                     continue;
                 }
                 if (pos+2 < offset_end) {
-                    char32_t cpt_next_next = _get_cpt(pos+2);
+                    uint32_t cpt_next_next = _get_cpt(pos+2);
                     if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                         (cpt_next == 'v' && cpt_next_next == 'e') ||
                         (cpt_next == 'l' && cpt_next_next == 'l')) {
@@ -320,9 +330,9 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
                 continue;
             }
             // regex: ?[^\s\p{L}\p{N}]+
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
                 _add_token(pos);
@@ -335,7 +345,7 @@ static std::vector unicode_regex_split_custom_gpt2(const std::string & t
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -370,13 +380,13 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
         assert(offset_end <= cpts.size());
         start = offset_end;
 
-        auto _get_cpt = [&] (const size_t pos) -> char32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
+        auto _get_cpt = [&] (const size_t pos) -> uint32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-            static const codepoint_flags undef(codepoint_flags::UNDEFINED);
-            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
         };
 
         size_t _prev_end = offset_ini;
@@ -397,18 +407,18 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
         };
 
         for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
-            const char32_t cpt = _get_cpt(pos);
+            const uint32_t cpt = _get_cpt(pos);
             const auto flags = _get_flags(pos);
 
             // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
             if (cpt == '\'' && pos+1 < offset_end) {
-                char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
                 if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
                     pos += _add_token(pos+2);
                     continue;
                 }
                 if (pos+2 < offset_end) {
-                    char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
                     if ((cpt_next == 'r' && cpt_next_next == 'e') ||
                         (cpt_next == 'v' && cpt_next_next == 'e') ||
                         (cpt_next == 'l' && cpt_next_next == 'l')) {
@@ -418,8 +428,8 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
                 }
             }
 
-            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
-            if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+
+            if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) {
                 if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
                     pos++;
                     while (_get_flags(pos).is_letter) {
@@ -445,12 +455,12 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
 
             // regex: ?[^\s\p{L}\p{N}]+[\r\n]*
             auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
-            if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+            if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) {
                 pos += (cpt == ' ');
-                while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) {
                     flags2 = _get_flags(++pos);
                 }
-                char32_t cpt2 = _get_cpt(pos);
+                uint32_t cpt2 = _get_cpt(pos);
                 while (cpt2 == '\r' || cpt2 == '\n') {
                     cpt2 = _get_cpt(++pos);
                 }
@@ -461,7 +471,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
             size_t num_whitespaces = 0;
             size_t last_end_r_or_n = 0;
             while (_get_flags(pos+num_whitespaces).is_whitespace) {
-                char32_t cpt2 = _get_cpt(pos+num_whitespaces);
+                uint32_t cpt2 = _get_cpt(pos+num_whitespaces);
                 if (cpt2 == '\r' || cpt2 == '\n') {
                     last_end_r_or_n = pos + num_whitespaces + 1;
                 }
@@ -476,7 +486,7 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -620,6 +630,7 @@ std::vector unicode_cpts_normalize_nfd(const std::vector & c
 
 std::vector unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));
@@ -652,7 +663,7 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
     return map.at(utf8);
 }
 
-char32_t unicode_tolower(char32_t cp) {
+uint32_t unicode_tolower(uint32_t cp) {
     auto it = unicode_map_lowercase.find(cp);
     return it == unicode_map_lowercase.end() ? cp : it->second;
 }
@@ -705,10 +716,14 @@ std::vector unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
+            const auto flags = unicode_cpt_flags(cpts[i]);
 
-            if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
+            if (flags.is_whitespace) {
+                //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
+                //text_collapsed[i] = (char) 0x85;  //  as whitespace fallback
+                text_collapsed[i] = (char) 0x0B;    //  as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
             } else {
                 text_collapsed[i] = (char) 0xD0; // fallback
             }
@@ -792,9 +807,16 @@ std::vector unicode_regex_split(const std::string & text, const std
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
             } else {
                 // no unicode category used, we can use std::wregex directly
-                const std::wstring wtext       = unicode_wstring_from_utf8(text);
                 const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
 
+                // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
+                std::wstring wtext(cpts.begin(), cpts.end());
+                for (size_t i = 0; i < wtext.size(); ++i) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                        wtext[i] = 0x0B;
+                    }
+                }
+
                 //printf("text: %s\n", text.c_str());
                 //printf("regex_expr: %s\n", regex_expr.c_str());
                 bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
diff --git a/llama/unicode.h b/llama/unicode.h
index 729b28e8..1850ceeb 100644
--- a/llama/unicode.h
+++ b/llama/unicode.h
@@ -1,5 +1,5 @@
 /**
- * llama.cpp - commit ee459f40f65810a810151b24eba5b8bd174ceffe - do not edit this file
+ * llama.cpp - commit 6eeaeba126ff701f3e8f79f246805b7023709972 - do not edit this file
  *
  * MIT License
  *
@@ -30,6 +30,8 @@
 #include 
 #include 
 
+// TODO: prefix all symbols with "llama_"
+
 struct codepoint_flags {
     enum {
         UNDEFINED       = 0x0001,
@@ -72,8 +74,10 @@ struct codepoint_flags {
     }
 };
 
+size_t unicode_len_utf8(char src);
 
 std::string unicode_cpt_to_utf8(uint32_t cp);
+uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
 std::vector unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector unicode_cpts_normalize_nfd(const std::vector & cpts);
@@ -84,6 +88,6 @@ codepoint_flags unicode_cpt_flags(const std::string & utf8);
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
 
-char32_t unicode_tolower(char32_t cp);
+uint32_t unicode_tolower(uint32_t cp);
 
 std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs);
diff --git a/llm/llm.go b/llm/llm.go
index 2a0c4b91..87cea5cd 100644
--- a/llm/llm.go
+++ b/llm/llm.go
@@ -1,12 +1,12 @@
 package llm
 
-// #cgo CFLAGS: -Illama.cpp
-// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
-// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
-// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
-// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
-// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
-// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
+// #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
+// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/src/libllama.a ${SRCDIR}/build/darwin/arm64_static/ggml/src/libggml.a -framework Accelerate -lstdc++
+// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/src/libllama.a ${SRCDIR}/build/darwin/x86_64_static/ggml/src/libggml.a -framework Accelerate -lstdc++
+// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/src/libllama.a ${SRCDIR}/build/windows/amd64_static/ggml/src/libggml.a -static -lstdc++
+// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/src/libllama.a ${SRCDIR}/build/windows/arm64_static/ggml/src/libggml.a -static -lstdc++
+// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/src/libllama.a ${SRCDIR}/build/linux/x86_64_static/ggml/src/libggml.a -lstdc++
+// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/src/libllama.a ${SRCDIR}/build/linux/arm64_static/ggml/src/libggml.a -lstdc++
 // #include 
 // #include "llama.h"
 import "C"