diff --git a/ggml-metal.o b/ggml-metal.o deleted file mode 100644 index 00e044d8..00000000 Binary files a/ggml-metal.o and /dev/null differ diff --git a/llama/build-info.cpp b/llama/build-info.cpp index 25857e6b..348fda3f 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,30 @@ +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "059031b8c40e1f4ba60586842c5b1ed3ddf61842"; +char const *LLAMA_COMMIT = ""; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/clip.cpp b/llama/clip.cpp index 26782a2a..e6229af0 100644 --- a/llama/clip.cpp +++ b/llama/clip.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/clip.h b/llama/clip.h index 1324b170..d6c28c5a 100644 --- a/llama/clip.h +++ b/llama/clip.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -94,7 +94,7 @@ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); -/** preprocess img and store the result in res_imgs, pad_to_square may be overriden to false depending on model configuration */ +/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); diff --git a/llama/common.cpp b/llama/common.cpp index c7af80ca..dd8bcee0 100644 --- a/llama/common.cpp +++ b/llama/common.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -33,20 +33,21 @@ #include #include +#include #include +#include +#include #include #include #include -#include #include +#include #include #include #include #include #include #include -#include -#include #if defined(__APPLE__) && defined(__MACH__) #include @@ -99,7 +100,11 @@ using json = nlohmann::ordered_json; -int32_t get_num_physical_cores() { +// +// CPU utils +// + +int32_t cpu_get_num_physical_cores() { #ifdef __linux__ // enumerate the set of thread siblings, num entries is num cores std::unordered_set siblings; @@ -168,9 +173,9 @@ static bool is_running_on_efficiency_core(void) { return core_type == intel_atom; } -static int count_math_cpus(int cpu_count) { +static int cpu_count_math_cpus(int n_cpu) { int result = 0; - for (int cpu = 0; cpu < cpu_count; ++cpu) { + for (int cpu = 0; cpu < n_cpu; ++cpu) { if (pin_cpu(cpu)) { return -1; } @@ -188,16 +193,16 @@ static int count_math_cpus(int cpu_count) { /** * Returns number of CPUs on system that are useful for math. */ -int get_math_cpu_count() { +int32_t cpu_get_num_math() { #if defined(__x86_64__) && defined(__linux__) && !defined(__ANDROID__) - int cpu_count = sysconf(_SC_NPROCESSORS_ONLN); - if (cpu_count < 1) { - return get_num_physical_cores(); + int n_cpu = sysconf(_SC_NPROCESSORS_ONLN); + if (n_cpu < 1) { + return cpu_get_num_physical_cores(); } if (is_hybrid_cpu()) { cpu_set_t affinity; if (!pthread_getaffinity_np(pthread_self(), sizeof(affinity), &affinity)) { - int result = count_math_cpus(cpu_count); + int result = cpu_count_math_cpus(n_cpu); pthread_setaffinity_np(pthread_self(), sizeof(affinity), &affinity); if (result > 0) { return result; @@ -205,109 +210,105 @@ int get_math_cpu_count() { } } #endif - return get_num_physical_cores(); + return cpu_get_num_physical_cores(); } -void process_escapes(std::string & input) { - std::size_t input_len = input.length(); - std::size_t output_idx = 0; +// +// CLI argument parsing +// - for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { - if (input[input_idx] == '\\' && input_idx + 1 < input_len) { - switch (input[++input_idx]) { - case 'n': input[output_idx++] = '\n'; break; - case 'r': input[output_idx++] = '\r'; break; - case 't': input[output_idx++] = '\t'; break; - case '\'': input[output_idx++] = '\''; break; - case '\"': input[output_idx++] = '\"'; break; - case '\\': input[output_idx++] = '\\'; break; - case 'x': - // Handle \x12, etc - if (input_idx + 2 < input_len) { - const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; - char *err_p = nullptr; - const long val = std::strtol(x, &err_p, 16); - if (err_p == x + 2) { - input_idx += 2; - input[output_idx++] = char(val); - break; - } - } - // fall through - default: input[output_idx++] = '\\'; - input[output_idx++] = input[input_idx]; break; +void gpt_params_handle_model_default(gpt_params & params) { + if (!params.hf_repo.empty()) { + // short-hand to avoid specifying --hf-file -> default it to --model + if (params.hf_file.empty()) { + if (params.model.empty()) { + throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); } - } else { - input[output_idx++] = input[input_idx]; + params.hf_file = params.model; + } else if (params.model.empty()) { + std::string cache_directory = fs_get_cache_directory(); + const bool success = fs_create_directory_with_parents(cache_directory); + if (!success) { + throw std::runtime_error("failed to create cache directory: " + cache_directory); + } + params.model = cache_directory + string_split(params.hf_file, '/').back(); + } + } else if (!params.model_url.empty()) { + if (params.model.empty()) { + auto f = string_split(params.model_url, '#').front(); + f = string_split(f, '?').front(); + f = string_split(f, '/').back(); + params.model = "models/" + f; + } + } else if (params.model.empty()) { + params.model = DEFAULT_MODEL_PATH; + } +} + +bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { + bool invalid_param = false; + std::string arg; + const std::string arg_prefix = "--"; + llama_sampling_params & sparams = params.sparams; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { + throw std::invalid_argument("error: unknown argument: " + arg); + } + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument: " + arg); } } - input.resize(output_idx); + if (params.prompt_cache_all && (params.interactive || params.interactive_first)) { + throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); + } + + gpt_params_handle_model_default(params); + + if (params.escape) { + string_process_escapes(params.prompt); + string_process_escapes(params.input_prefix); + string_process_escapes(params.input_suffix); + string_process_escapes(sparams.cfg_negative_prompt); + for (auto & antiprompt : params.antiprompt) { + string_process_escapes(antiprompt); + } + } + + if (!params.kv_overrides.empty()) { + params.kv_overrides.emplace_back(); + params.kv_overrides.back().key[0] = 0; + } + + return true; } bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { - bool result = true; - try { - if (!gpt_params_parse_ex(argc, argv, params)) { - gpt_print_usage(argc, argv, gpt_params()); - exit(0); - } - } - catch (const std::invalid_argument & ex) { - fprintf(stderr, "%s\n", ex.what()); - gpt_print_usage(argc, argv, gpt_params()); - exit(1); - } - return result; -} + const auto params_org = params; // the example can modify the default params -bool parse_kv_override(const char * data, std::vector & overrides) { - const char * sep = strchr(data, '='); - if (sep == nullptr || sep - data >= 128) { - fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); - return false; - } - llama_model_kv_override kvo; - std::strncpy(kvo.key, data, sep - data); - kvo.key[sep - data] = 0; - sep++; - if (strncmp(sep, "int:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = std::atol(sep); - } else if (strncmp(sep, "float:", 6) == 0) { - sep += 6; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; - kvo.val_f64 = std::atof(sep); - } else if (strncmp(sep, "bool:", 5) == 0) { - sep += 5; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; - if (std::strcmp(sep, "true") == 0) { - kvo.val_bool = true; - } else if (std::strcmp(sep, "false") == 0) { - kvo.val_bool = false; - } else { - fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); + try { + if (!gpt_params_parse_ex(argc, argv, params) || params.usage) { + params = params_org; + params.usage = true; return false; } - } else if (strncmp(sep, "str:", 4) == 0) { - sep += 4; - kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; - if (strlen(sep) > 127) { - fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); - return false; - } - strncpy(kvo.val_str, sep, 127); - kvo.val_str[127] = '\0'; - } else { - fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + } catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + params = params_org; return false; } - overrides.emplace_back(std::move(kvo)); + return true; } bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { + const char split_delim = ','; + llama_sampling_params & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { @@ -315,7 +316,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - // This is temporary, in the future the samplign state will be moved fully to llama_sampling_context. + // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context. params.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]); return true; @@ -376,6 +377,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.escape = true; return true; } + if (arg == "--no-escape") { + params.escape = false; + return true; + } if (arg == "--prompt-cache") { if (++i >= argc) { invalid_param = true; @@ -430,7 +435,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } - if (arg == "-n" || arg == "--n-predict") { + if (arg == "--in-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + return true; + } + params.in_files.push_back(argv[i]); + return true; + } + if (arg == "-n" || arg == "--predict" || arg == "--n-predict") { if (++i >= argc) { invalid_param = true; return true; @@ -572,7 +591,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = sampler_types_from_names(sampler_names, true); + sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { @@ -580,7 +599,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - sparams.samplers_sequence = sampler_types_from_chars(argv[i]); + sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -927,30 +946,22 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.interactive = true; return true; } - if (arg == "--interactive-specials") { - params.interactive_specials = true; + if (arg == "-sp" || arg == "--special") { + params.special = true; return true; } - if (arg == "--embedding") { + if (arg == "--embedding" || arg == "--embeddings") { params.embedding = true; return true; } - if (arg == "--interactive-first") { + if (arg == "-if" || arg == "--interactive-first") { params.interactive_first = true; return true; } - if (arg == "-ins" || arg == "--instruct") { - params.instruct = true; - return true; - } if (arg == "-cnv" || arg == "--conversation") { params.conversation = true; return true; } - if (arg == "-cml" || arg == "--chatml") { - params.chatml = true; - return true; - } if (arg == "--infill") { params.infill = true; return true; @@ -987,7 +998,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.flash_attn = true; return true; } - if (arg == "--color") { + if (arg == "-co" || arg == "--color") { params.use_color = true; return true; } @@ -995,26 +1006,26 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.use_mlock = true; return true; } - if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { + if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { if (++i >= argc) { invalid_param = true; return true; } params.n_gpu_layers = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers option will be ignored\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); } return true; } - if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { + if (arg == "-ngld" || arg == "--gpu-layers-draft" || arg == "--gpu-layers-draft") { if (++i >= argc) { invalid_param = true; return true; } params.n_gpu_layers_draft = std::stoi(argv[i]); if (!llama_supports_gpu_offload()) { - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: not compiled with GPU offload support, --gpu-layers-draft option will be ignored\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); } return true; @@ -1025,9 +1036,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } params.main_gpu = std::stoi(argv[i]); -#ifndef GGML_USE_CUDA_SYCL - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the main GPU has no effect.\n"); -#endif // GGML_USE_CUDA_SYCL +#ifndef GGML_USE_CUDA_SYCL_VULKAN + fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the main GPU has no effect.\n"); +#endif // GGML_USE_CUDA_SYCL_VULKAN return true; } if (arg == "--split-mode" || arg == "-sm") { @@ -1053,9 +1064,9 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } -#ifndef GGML_USE_CUDA_SYCL - fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL. Setting the split mode has no effect.\n"); -#endif // GGML_USE_CUDA_SYCL +#ifndef GGML_USE_CUDA_SYCL_VULKAN + fprintf(stderr, "warning: llama.cpp was compiled without CUDA/SYCL/Vulkan. Setting the split mode has no effect.\n"); +#endif // GGML_USE_CUDA_SYCL_VULKAN return true; } if (arg == "--tensor-split" || arg == "-ts") { @@ -1110,6 +1121,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa else { invalid_param = true; } return true; } + if (arg == "-v" || arg == "--verbose") { + params.verbosity = 1; + return true; + } + if (arg == "--verbosity") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.verbosity = std::stoi(argv[i]); + return true; + } if (arg == "--verbose-prompt") { params.verbose_prompt = true; return true; @@ -1174,6 +1197,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.ppl_stride = std::stoi(argv[i]); return true; } + if (arg == "--ppl-output-type") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.ppl_output_type = std::stoi(argv[i]); + return true; + } if (arg == "-ptc" || arg == "--print-token-count") { if (++i >= argc) { invalid_param = true; @@ -1186,14 +1217,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.check_tensors = true; return true; } - if (arg == "--ppl-output-type") { - if (++i >= argc) { - invalid_param = true; - return true; - } - params.ppl_output_type = std::stoi(argv[i]); - return true; - } if (arg == "--hellaswag") { params.hellaswag = true; return true; @@ -1265,19 +1288,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } - if (arg == "-h" || arg == "--help") { - gpt_print_usage(argc, argv, gpt_params()); - exit(0); + if (arg == "-h" || arg == "--help" || arg == "--usage" ) { + params.usage = true; + return true; } if (arg == "--version") { fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT); fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET); exit(0); } - if (arg == "--random-prompt") { - params.random_prompt = true; - return true; - } if (arg == "--in-prefix-bos") { params.input_prefix_bos = true; return true; @@ -1337,13 +1356,284 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa invalid_param = true; return true; } - if (!parse_kv_override(argv[i], params.kv_overrides)) { + if (!string_parse_kv_override(argv[i], params.kv_overrides)) { fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]); invalid_param = true; return true; } return true; } + if (arg == "--host") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.hostname = argv[i]; + return true; + } + if (arg == "--port") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.port = std::stoi(argv[i]); + return true; + } + if (arg == "--path") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.public_path = argv[i]; + return true; + } + if (arg == "--api-key") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.api_keys.push_back(argv[i]); + return true; + } + if (arg == "--api-key-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::ifstream key_file(argv[i]); + if (!key_file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + return true; + } + std::string key; + while (std::getline(key_file, key)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } + key_file.close(); + return true; + } + if (arg == "--ssl-key-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.ssl_file_key = argv[i]; + return true; + } + if (arg == "--ssl-cert-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.ssl_file_cert = argv[i]; + return true; + } + if (arg == "--timeout" || arg == "-to") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.timeout_read = std::stoi(argv[i]); + params.timeout_write = std::stoi(argv[i]); + return true; + } + if (arg == "--threads-http") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_threads_http = std::stoi(argv[i]); + return true; + } + if (arg == "-spf" || arg == "--system-prompt-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + return true; + } + std::string system_prompt; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(system_prompt) + ); + params.system_prompt = system_prompt; + return true; + } + if (arg == "--log-format") { + if (++i >= argc) { + invalid_param = true; + return true; + } + if (std::strcmp(argv[i], "json") == 0) { + params.log_json = true; + } else if (std::strcmp(argv[i], "text") == 0) { + params.log_json = false; + } else { + invalid_param = true; + return true; + } + return true; + } + if (arg == "--no-slots") { + params.endpoint_slots = false; + return true; + } + if (arg == "--metrics") { + params.endpoint_metrics = true; + return true; + } + if (arg == "--slot-save-path") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.slot_save_path = argv[i]; + // if doesn't end with DIRECTORY_SEPARATOR, add it + if (!params.slot_save_path.empty() && params.slot_save_path[params.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) { + params.slot_save_path += DIRECTORY_SEPARATOR; + } + return true; + } + if (arg == "--chat-template") { + if (++i >= argc) { + invalid_param = true; + return true; + } + if (!llama_chat_verify_template(argv[i])) { + fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]); + fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n"); + invalid_param = true; + return true; + } + params.chat_template = argv[i]; + return true; + } + if (arg == "-pps") { + params.is_pp_shared = true; + return true; + } + if (arg == "-npp") { + if (++i >= argc) { + invalid_param = true; + return true; + } + auto p = string_split(argv[i], split_delim); + params.n_pp.insert(params.n_pp.end(), p.begin(), p.end()); + return true; + } + if (arg == "-ntg") { + if (++i >= argc) { + invalid_param = true; + return true; + } + auto p = string_split(argv[i], split_delim); + params.n_tg.insert(params.n_tg.end(), p.begin(), p.end()); + return true; + } + if (arg == "-npl") { + if (++i >= argc) { + invalid_param = true; + return true; + } + auto p = string_split(argv[i], split_delim); + params.n_pl.insert(params.n_pl.end(), p.begin(), p.end()); + return true; + } + if (arg == "--context-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::ifstream file(argv[i], std::ios::binary); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + return true; + } + params.context_files.push_back(argv[i]); + return true; + } + if (arg == "--chunk-size") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.chunk_size = std::stoi(argv[i]); + return true; + } + if (arg == "--chunk-separator") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.chunk_separator = argv[i]; + return true; + } + if (arg == "--junk") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_junk = std::stoi(argv[i]); + return true; + } + if (arg == "--pos") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.i_pos = std::stoi(argv[i]); + return true; + } + if (arg == "-o" || arg == "--output" || arg == "--output-file") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.out_file = argv[i]; + return true; + } + if (arg == "-ofreq" || arg == "--output-frequency") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_out_freq = std::stoi(argv[i]); + return true; + } + if (arg == "--save-frequency") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.n_save_freq = std::stoi(argv[i]); + return true; + } + if (arg == "--process-output") { + params.process_output = true; + return true; + } + if (arg == "--no-ppl") { + params.compute_ppl = false; + return true; + } + if (arg == "--chunk" || arg == "--from-chunk") { + if (++i >= argc) { + invalid_param = true; + return true; + } + params.i_chunk = std::stoi(argv[i]); + return true; + } #ifndef LOG_DISABLE_LOGS // Parse args for logging parameters if (log_param_single_parse(argv[i])) { @@ -1371,280 +1661,325 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return false; } -void gpt_params_handle_model_default(gpt_params & params) { - if (!params.hf_repo.empty()) { - // short-hand to avoid specifying --hf-file -> default it to --model - if (params.hf_file.empty()) { - if (params.model.empty()) { - throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n"); - } - params.hf_file = params.model; - } else if (params.model.empty()) { - params.model = "models/" + string_split(params.hf_file, '/').back(); - } - } else if (!params.model_url.empty()) { - if (params.model.empty()) { - auto f = string_split(params.model_url, '#').front(); - f = string_split(f, '?').front(); - f = string_split(f, '/').back(); - params.model = "models/" + f; - } - } else if (params.model.empty()) { - params.model = DEFAULT_MODEL_PATH; - } -} +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +#endif -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { - bool invalid_param = false; - std::string arg; - const std::string arg_prefix = "--"; - llama_sampling_params & sparams = params.sparams; - - for (int i = 1; i < argc; i++) { - arg = argv[i]; - if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { - std::replace(arg.begin(), arg.end(), '_', '-'); - } - if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { - throw std::invalid_argument("error: unknown argument: " + arg); - } - if (invalid_param) { - throw std::invalid_argument("error: invalid parameter for argument: " + arg); - } - } - - if (params.prompt_cache_all && - (params.interactive || params.interactive_first || - params.instruct)) { - - throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); - } - - gpt_params_handle_model_default(params); - - if (params.escape) { - process_escapes(params.prompt); - process_escapes(params.input_prefix); - process_escapes(params.input_suffix); - process_escapes(sparams.cfg_negative_prompt); - for (auto & antiprompt : params.antiprompt) { - process_escapes(antiprompt); - } - } - - if (!params.kv_overrides.empty()) { - params.kv_overrides.emplace_back(); - params.kv_overrides.back().key[0] = 0; - } - - return true; -} - -void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { +void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const llama_sampling_params & sparams = params.sparams; std::string sampler_type_chars; std::string sampler_type_names; for (const auto sampler_type : sparams.samplers_sequence) { sampler_type_chars += static_cast(sampler_type); - sampler_type_names += sampler_type_to_name_string(sampler_type) + ";"; + sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";"; } sampler_type_names.pop_back(); - printf("\n"); - printf("usage: %s [options]\n", argv[0]); - printf("\n"); - printf("options:\n"); - printf(" -h, --help show this help message and exit\n"); - printf(" --version show version and build info\n"); - printf(" -i, --interactive run in interactive mode\n"); - printf(" --interactive-specials allow special tokens in user text, in interactive mode\n"); - printf(" --interactive-first run in interactive mode and wait for input right away\n"); - printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); - printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); - printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); - printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); - printf(" -r PROMPT, --reverse-prompt PROMPT\n"); - printf(" halt generation at PROMPT, return control in interactive mode\n"); - printf(" (can be specified more than once for multiple prompts).\n"); - printf(" --color colorise output to distinguish prompt and user input from generations\n"); - printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); - printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads); - printf(" -tb N, --threads-batch N\n"); - printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n"); - printf(" -td N, --threads-draft N"); - printf(" number of threads to use during generation (default: same as --threads)\n"); - printf(" -tbd N, --threads-batch-draft N\n"); - printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n"); - printf(" -p PROMPT, --prompt PROMPT\n"); - printf(" prompt to start generation with (default: empty)\n"); - printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n"); - printf(" --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n"); - printf(" --prompt-cache-all if specified, saves user input and generations to cache as well.\n"); - printf(" not supported with --interactive or other interactive options\n"); - printf(" --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); - printf(" --random-prompt start with a randomized prompt.\n"); - printf(" --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); - printf(" --in-prefix STRING string to prefix user inputs with (default: empty)\n"); - printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); - printf(" -f FNAME, --file FNAME\n"); - printf(" prompt file to start generation.\n"); - printf(" -bf FNAME, --binary-file FNAME\n"); - printf(" binary file containing multiple choice tasks.\n"); - printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); - printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); - printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch); - printf(" -ub N, --ubatch-size N\n"); - printf(" physical maximum batch size (default: %d)\n", params.n_ubatch); - printf(" --samplers samplers that will be used for generation in the order, separated by \';\'\n"); - printf(" (default: %s)\n", sampler_type_names.c_str()); - printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sampler_type_chars.c_str()); - printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k); - printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); - printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p); - printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); - printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); - printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n); - printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat); - printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present); - printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq); - printf(" --dynatemp-range N dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range); - printf(" --dynatemp-exp N dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent); - printf(" --mirostat N use Mirostat sampling.\n"); - printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); - printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); - printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta); - printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau); - printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); - printf(" modifies the likelihood of token appearing in the completion,\n"); - printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); - printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); - printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); - printf(" --grammar-file FNAME file to read grammar from\n"); - printf(" -j SCHEMA, --json-schema SCHEMA\n"); - printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n"); - printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n"); - printf(" --cfg-negative-prompt PROMPT\n"); - printf(" negative prompt to use for guidance. (default: empty)\n"); - printf(" --cfg-negative-prompt-file FNAME\n"); - printf(" negative prompt file to use for guidance. (default: empty)\n"); - printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale); - printf(" --rope-scaling {none,linear,yarn}\n"); - printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n"); - printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n"); - printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n"); - printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n"); - printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n"); - printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n"); - printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); - printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); - printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); - printf(" --pooling {none,mean,cls}\n"); - printf(" pooling type for embeddings, use model default if unspecified\n"); - printf(" -dt N, --defrag-thold N\n"); - printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); - printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); - printf(" --penalize-nl penalize newline tokens\n"); - printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp); - printf(" --all-logits return logits for all tokens in the batch (default: disabled)\n"); - printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); - printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); - printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n"); - printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks); - printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n"); - printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks); - printf(" --kl-divergence computes KL-divergence to logits provided via --kl-divergence-base\n"); - printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); - printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); - printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); - printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); - printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split); - printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); - printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled"); - printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); - printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n"); + struct option_info { + LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) + option_info(const std::string & tags, const char * args, const char * desc, ...) : tags(tags), args(args), desc(desc) { + va_list args_list; + va_start(args_list, desc); + char buffer[1024]; + vsnprintf(buffer, sizeof(buffer), desc, args_list); + va_end(args_list); + this->desc = buffer; + } + + option_info(const std::string & grp) : grp(grp) {} + + std::string tags; + std::string args; + std::string desc; + std::string grp; + }; + + std::vector options; + + // TODO: filter by tags + + options.push_back({ "general" }); + options.push_back({ "*", "-h, --help, --usage", "print usage and exit" }); + options.push_back({ "*", " --version", "show version and build info" }); + options.push_back({ "*", "-v, --verbose", "print verbose information" }); + options.push_back({ "*", " --verbosity N", "set specific verbosity level (default: %d)", params.verbosity }); + options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" }); + options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); + options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); + options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); + options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads }); + options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); + options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); + options.push_back({ "speculative", "-tbd, --threads-batch-draft N", + "number of threads to use during batch and prompt processing (default: same as --threads-draft)" }); + options.push_back({ "speculative", " --draft N", "number of tokens to draft for speculative decoding (default: %d)", params.n_draft }); + options.push_back({ "speculative", "-ps, --p-split N", "speculative decoding split probability (default: %.1f)", (double)params.p_split }); + options.push_back({ "*", "-lcs, --lookup-cache-static FNAME", + "path to static lookup cache to use for lookup decoding (not updated by generation)" }); + options.push_back({ "*", "-lcd, --lookup-cache-dynamic FNAME", + "path to dynamic lookup cache to use for lookup decoding (updated by generation)" }); + + options.push_back({ "*", "-c, --ctx-size N", "size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx }); + options.push_back({ "*", "-n, --predict N", "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)", params.n_predict }); + options.push_back({ "*", "-b, --batch-size N", "logical maximum batch size (default: %d)", params.n_batch }); + options.push_back({ "*", "-ub, --ubatch-size N", "physical maximum batch size (default: %d)", params.n_ubatch }); + options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); + options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); + options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with (default: '%s')", params.prompt.c_str() }); + options.push_back({ "*", "-f, --file FNAME", "a file containing the prompt (default: none)" }); + options.push_back({ "*", " --in-file FNAME", "an input file (repeat to specify multiple files)" }); + options.push_back({ "*", "-bf, --binary-file FNAME", "binary file containing the prompt (default: none)" }); + options.push_back({ "*", "-e, --escape", "process escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\) (default: %s)", params.escape ? "true" : "false" }); + options.push_back({ "*", " --no-escape", "do not process escape sequences" }); + options.push_back({ "main", "-ptc, --print-token-count N", "print token count every N tokens (default: %d)", params.n_print }); + options.push_back({ "main", " --prompt-cache FNAME", "file to cache prompt state for faster startup (default: none)" }); + options.push_back({ "main", " --prompt-cache-all", "if specified, saves user input and generations to cache as well\n" + "not supported with --interactive or other interactive options" }); + options.push_back({ "main", " --prompt-cache-ro", "if specified, uses the prompt cache but does not update it" }); + options.push_back({ "main", "-r, --reverse-prompt PROMPT", + "halt generation at PROMPT, return control in interactive mode\n" + "can be specified more than once for multiple prompts" }); + options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" }); + options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" }); + options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" }); + options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" }); + options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" }); + options.push_back({ "main infill", " --in-prefix-bos", "prefix BOS to user inputs, preceding the `--in-prefix` string" }); + options.push_back({ "main infill", " --in-prefix STRING", "string to prefix user inputs with (default: empty)" }); + options.push_back({ "main infill", " --in-suffix STRING", "string to suffix after user inputs with (default: empty)" }); + + options.push_back({ "sampling" }); + options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" + "(default: %s)", sampler_type_names.c_str() }); + options.push_back({ "*", " --sampling-seq SEQUENCE", + "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); + options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); + options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" }); + options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp }); + options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k }); + options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); + options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); + options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); + options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p }); + options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n }); + options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat }); + options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present }); + options.push_back({ "*", " --frequency-penalty N", "repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_freq }); + options.push_back({ "*", " --dynatemp-range N", "dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)sparams.dynatemp_range }); + options.push_back({ "*", " --dynatemp-exp N", "dynamic temperature exponent (default: %.1f)", (double)sparams.dynatemp_exponent }); + options.push_back({ "*", " --mirostat N", "use Mirostat sampling.\n" + "Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n" + "(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", sparams.mirostat }); + options.push_back({ "*", " --mirostat-lr N", "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta }); + options.push_back({ "*", " --mirostat-ent N", "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau }); + options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" + "i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" + "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); + options.push_back({ "main", " --cfg-negative-prompt PROMPT", + "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); + options.push_back({ "main", " --cfg-negative-prompt-file FNAME", + "negative prompt file to use for guidance" }); + options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); + + options.push_back({ "grammar" }); + options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() }); + options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" }); + options.push_back({ "*", "-j, --json-schema SCHEMA", + "JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\n" + "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" }); + + options.push_back({ "embedding" }); + options.push_back({ "embedding", " --pooling {none,mean,cls}", + "pooling type for embeddings, use model default if unspecified" }); + + options.push_back({ "context hacking" }); + options.push_back({ "*", " --rope-scaling {none,linear,yarn}", + "RoPE frequency scaling method, defaults to linear unless specified by the model" }); + options.push_back({ "*", " --rope-scale N", "RoPE context scaling factor, expands context by a factor of N" }); + options.push_back({ "*", " --rope-freq-base N", "RoPE base frequency, used by NTK-aware scaling (default: loaded from model)" }); + options.push_back({ "*", " --rope-freq-scale N", "RoPE frequency scaling factor, expands context by a factor of 1/N" }); + options.push_back({ "*", " --yarn-orig-ctx N", "YaRN: original context size of model (default: %d = model training context size)", params.yarn_orig_ctx }); + options.push_back({ "*", " --yarn-ext-factor N", "YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor }); + options.push_back({ "*", " --yarn-attn-factor N", "YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor }); + options.push_back({ "*", " --yarn-beta-slow N", "YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow }); + options.push_back({ "*", " --yarn-beta-fast N", "YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast }); + options.push_back({ "*", "-gan, --grp-attn-n N", "group-attention factor (default: %d)", params.grp_attn_n }); + options.push_back({ "*", "-gaw, --grp-attn-w N", "group-attention width (default: %.1f)", (double)params.grp_attn_w }); + options.push_back({ "*", "-dkvc, --dump-kv-cache", "verbose print of the KV cache" }); + options.push_back({ "*", "-nkvo, --no-kv-offload", "disable KV offload" }); + options.push_back({ "*", "-ctk, --cache-type-k TYPE", "KV cache data type for K (default: %s)", params.cache_type_k.c_str() }); + options.push_back({ "*", "-ctv, --cache-type-v TYPE", "KV cache data type for V (default: %s)", params.cache_type_v.c_str() }); + + options.push_back({ "perplexity" }); + options.push_back({ "perplexity", " --all-logits", "return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false" }); + options.push_back({ "perplexity", " --hellaswag", "compute HellaSwag score over random tasks from datafile supplied with -f" }); + options.push_back({ "perplexity", " --hellaswag-tasks N", "number of tasks to use when computing the HellaSwag score (default: %zu)", params.hellaswag_tasks }); + options.push_back({ "perplexity", " --winogrande", "compute Winogrande score over random tasks from datafile supplied with -f" }); + options.push_back({ "perplexity", " --winogrande-tasks N", "number of tasks to use when computing the Winogrande score (default: %zu)", params.winogrande_tasks }); + options.push_back({ "perplexity", " --multiple-choice", "compute multiple choice score over random tasks from datafile supplied with -f" }); + options.push_back({ "perplexity", " --multiple-choice-tasks N", + "number of tasks to use when computing the multiple choice score (default: %zu)", params.multiple_choice_tasks }); + options.push_back({ "perplexity", " --kl-divergence", "computes KL-divergence to logits provided via --kl-divergence-base" }); + options.push_back({ "perplexity", " --ppl-stride N", "stride for perplexity calculation (default: %d)", params.ppl_stride }); + options.push_back({ "perplexity", " --ppl-output-type {0,1}", + "output type for perplexity calculation (default: %d)", params.ppl_output_type }); + + options.push_back({ "parallel" }); + options.push_back({ "*", "-dt, --defrag-thold N", "KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold }); + options.push_back({ "*", "-np, --parallel N", "number of parallel sequences to decode (default: %d)", params.n_parallel }); + options.push_back({ "*", "-ns, --sequences N", "number of sequences to decode (default: %d)", params.n_sequences }); + options.push_back({ "*", "-cb, --cont-batching", "enable continuous batching (a.k.a dynamic batching) (default: %s)", params.cont_batching ? "enabled" : "disabled" }); + + options.push_back({ "multi-modality" }); + options.push_back({ "*", " --mmproj FILE", "path to a multimodal projector file for LLaVA. see examples/llava/README.md" }); + options.push_back({ "*", " --image FILE", "path to an image file. use with multimodal models. Specify multiple times for batching" }); + + options.push_back({ "backend" }); + options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" }); if (llama_supports_mlock()) { - printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); + options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" }); } if (llama_supports_mmap()) { - printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); + options.push_back({ "*", " --no-mmap", "do not memory-map model (slower load but may reduce pageouts if not using mlock)" }); } - printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n"); - printf(" - distribute: spread execution evenly over all nodes\n"); - printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n"); - printf(" - numactl: use the CPU map provided by numactl\n"); - printf(" if run without this previously, it is recommended to drop the system page cache before using this\n"); - printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n"); + options.push_back({ "*", " --numa TYPE", "attempt optimizations that help on some NUMA systems\n" + " - distribute: spread execution evenly over all nodes\n" + " - isolate: only spawn threads on CPUs on the node that execution started on\n" + " - numactl: use the CPU map provided by numactl\n" + "if run without this previously, it is recommended to drop the system page cache before using this\n" + "see https://github.com/ggerganov/llama.cpp/issues/1437" }); + if (llama_supports_gpu_offload()) { - printf(" -ngl N, --n-gpu-layers N\n"); - printf(" number of layers to store in VRAM\n"); - printf(" -ngld N, --n-gpu-layers-draft N\n"); - printf(" number of layers to store in VRAM for the draft model\n"); - printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n"); - printf(" how to split the model across multiple GPUs, one of:\n"); - printf(" - none: use one GPU only\n"); - printf(" - layer (default): split layers and KV across GPUs\n"); - printf(" - row: split rows across GPUs\n"); - printf(" -ts SPLIT, --tensor-split SPLIT\n"); - printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n"); - printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); - printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu); + options.push_back({ "*", "-ngl, --gpu-layers N", + "number of layers to store in VRAM" }); + options.push_back({ "*", "-ngld, --gpu-layers-draft N", + "number of layers to store in VRAM for the draft model" }); + options.push_back({ "*", "-sm, --split-mode SPLIT_MODE", + "how to split the model across multiple GPUs, one of:\n" + " - none: use one GPU only\n" + " - layer (default): split layers and KV across GPUs\n" + " - row: split rows across GPUs" }); + options.push_back({ "*", "-ts, --tensor-split SPLIT", + "fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1" }); + options.push_back({ "*", "-mg, --main-gpu i", "the GPU to use for the model (with split-mode = none),\n" + "or for intermediate results and KV (with split-mode = row) (default: %d)", params.main_gpu }); } - printf(" --rpc SERVERS comma separated list of RPC servers\n"); - printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false"); - printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false"); - printf(" -gan N, --grp-attn-n N\n"); - printf(" group-attention factor (default: %d)\n", params.grp_attn_n); - printf(" -gaw N, --grp-attn-w N\n"); - printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w); - printf(" -dkvc, --dump-kv-cache\n"); - printf(" verbose print of the KV cache\n"); - printf(" -nkvo, --no-kv-offload\n"); - printf(" disable KV offload\n"); - printf(" -ctk TYPE, --cache-type-k TYPE\n"); - printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str()); - printf(" -ctv TYPE, --cache-type-v TYPE\n"); - printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str()); - printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); - printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); - printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n"); - printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); - printf(" --control-vector FNAME\n"); - printf(" add a control vector\n"); - printf(" --control-vector-scaled FNAME S\n"); - printf(" add a control vector with user defined scaling S\n"); - printf(" --control-vector-layer-range START END\n"); - printf(" layer range to apply the control vector(s) to, start and end inclusive\n"); - printf(" -m FNAME, --model FNAME\n"); - printf(" model path (default: models/$filename with filename from --hf-file or --model-url if set, otherwise %s)\n", DEFAULT_MODEL_PATH); - printf(" -md FNAME, --model-draft FNAME\n"); - printf(" draft model for speculative decoding (default: unused)\n"); - printf(" -mu MODEL_URL, --model-url MODEL_URL\n"); - printf(" model download url (default: unused)\n"); - printf(" -hfr REPO, --hf-repo REPO\n"); - printf(" Hugging Face model repository (default: unused)\n"); - printf(" -hff FILE, --hf-file FILE\n"); - printf(" Hugging Face model file (default: unused)\n"); - printf(" -ld LOGDIR, --logdir LOGDIR\n"); - printf(" path under which to save YAML logs (no logging if unset)\n"); - printf(" -lcs FNAME, --lookup-cache-static FNAME\n"); - printf(" path to static lookup cache to use for lookup decoding (not updated by generation)\n"); - printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n"); - printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n"); - printf(" --override-kv KEY=TYPE:VALUE\n"); - printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); - printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); - printf(" -ptc N, --print-token-count N\n"); - printf(" print token count every N tokens (default: %d)\n", params.n_print); - printf(" --check-tensors check model tensor data for invalid values\n"); - printf("\n"); + + options.push_back({ "model" }); + options.push_back({ "*", " --check-tensors", "check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false" }); + options.push_back({ "*", " --override-kv KEY=TYPE:VALUE", + "advanced option to override model metadata by key. may be specified multiple times.\n" + "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false" }); + options.push_back({ "*", " --lora FNAME", "apply LoRA adapter (implies --no-mmap)" }); + options.push_back({ "*", " --lora-scaled FNAME S", "apply LoRA adapter with user defined scaling S (implies --no-mmap)" }); + options.push_back({ "*", " --lora-base FNAME", "optional model to use as a base for the layers modified by the LoRA adapter" }); + options.push_back({ "*", " --control-vector FNAME", "add a control vector" }); + options.push_back({ "*", " --control-vector-scaled FNAME SCALE", + "add a control vector with user defined scaling SCALE" }); + options.push_back({ "*", " --control-vector-layer-range START END", + "layer range to apply the control vector(s) to, start and end inclusive" }); + options.push_back({ "*", "-m, --model FNAME", "model path (default: models/$filename with filename from --hf-file\n" + "or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH }); + options.push_back({ "*", "-md, --model-draft FNAME", "draft model for speculative decoding (default: unused)" }); + options.push_back({ "*", "-mu, --model-url MODEL_URL", "model download url (default: unused)" }); + options.push_back({ "*", "-hfr, --hf-repo REPO", "Hugging Face model repository (default: unused)" }); + options.push_back({ "*", "-hff, --hf-file FILE", "Hugging Face model file (default: unused)" }); + + options.push_back({ "retrieval" }); + options.push_back({ "retrieval", " --context-file FNAME", "file to load context from (repeat to specify multiple files)" }); + options.push_back({ "retrieval", " --chunk-size N", "minimum length of embedded text chunks (default: %d)", params.chunk_size }); + options.push_back({ "retrieval", " --chunk-separator STRING", + "separator between chunks (default: '%s')", params.chunk_separator.c_str() }); + + options.push_back({ "passkey" }); + options.push_back({ "passkey", " --junk N", "number of times to repeat the junk text (default: %d)", params.n_junk }); + options.push_back({ "passkey", " --pos N", "position of the passkey in the junk text (default: %d)", params.i_pos }); + + options.push_back({ "imatrix" }); + options.push_back({ "imatrix", "-o, --output FNAME", "output file (default: '%s')", params.out_file.c_str() }); + options.push_back({ "imatrix", " --output-frequency N", "output the imatrix every N iterations (default: %d)", params.n_out_freq }); + options.push_back({ "imatrix", " --save-frequency N", "save an imatrix copy every N iterations (default: %d)", params.n_save_freq }); + options.push_back({ "imatrix", " --process-output", "collect data for the output tensor (default: %s)", params.process_output ? "true" : "false" }); + options.push_back({ "imatrix", " --no-ppl", "do not compute perplexity (default: %s)", params.compute_ppl ? "true" : "false" }); + options.push_back({ "imatrix", " --chunk N", "start processing the input from chunk N (default: %d)", params.i_chunk }); + + options.push_back({ "bench" }); + options.push_back({ "bench", "-pps", "is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false" }); + options.push_back({ "bench", "-npp n0,n1,...", "number of prompt tokens" }); + options.push_back({ "bench", "-ntg n0,n1,...", "number of text generation tokens" }); + options.push_back({ "bench", "-npl n0,n1,...", "number of parallel prompts" }); + + options.push_back({ "server" }); + options.push_back({ "server", " --host HOST", "ip address to listen (default: %s)", params.hostname.c_str() }); + options.push_back({ "server", " --port PORT", "port to listen (default: %d)", params.port }); + options.push_back({ "server", " --path PATH", "path to serve static files from (default: %s)", params.public_path.c_str() }); + options.push_back({ "server", " --embedding(s)", "enable embedding endpoint (default: %s)", params.embedding ? "enabled" : "disabled" }); + options.push_back({ "server", " --api-key KEY", "API key to use for authentication (default: none)" }); + options.push_back({ "server", " --api-key-file FNAME", "path to file containing API keys (default: none)" }); + options.push_back({ "server", " --ssl-key-file FNAME", "path to file a PEM-encoded SSL private key" }); + options.push_back({ "server", " --ssl-cert-file FNAME", "path to file a PEM-encoded SSL certificate" }); + options.push_back({ "server", " --timeout N", "server read/write timeout in seconds (default: %d)", params.timeout_read }); + options.push_back({ "server", " --threads-http N", "number of threads used to process HTTP requests (default: %d)", params.n_threads_http }); + options.push_back({ "server", " --system-prompt-file FNAME", + "set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications" }); + options.push_back({ "server", " --log-format {text,json}", + "log output format: json or text (default: json)" }); + options.push_back({ "server", " --metrics", "enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled" }); + options.push_back({ "server", " --no-slots", "disables slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled" }); + options.push_back({ "server", " --slot-save-path PATH", "path to save slot kv cache (default: disabled)" }); + options.push_back({ "server", " --chat-template JINJA_TEMPLATE", + "set custom jinja chat template (default: template taken from model's metadata)\n" + "only commonly used templates are accepted:\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + #ifndef LOG_DISABLE_LOGS - log_print_usage(); + options.push_back({ "logging" }); + options.push_back({ "*", " --simple-io", "use basic IO for better compatibility in subprocesses and limited consoles" }); + options.push_back({ "*", "-ld, --logdir LOGDIR", "path under which to save YAML logs (no logging if unset)" }); + options.push_back({ "logging", " --log-test", "Run simple logging test" }); + options.push_back({ "logging", " --log-disable", "Disable trace logs" }); + options.push_back({ "logging", " --log-enable", "Enable trace logs" }); + options.push_back({ "logging", " --log-file FNAME", "Specify a log filename (without extension)" }); + options.push_back({ "logging", " --log-new", "Create a separate new log file on start. " + "Each log file will have unique name: \"..log\"" }); + options.push_back({ "logging", " --log-append", "Don't truncate the old log file." }); #endif // LOG_DISABLE_LOGS + + printf("usage: %s [options]\n", argv[0]); + + for (const auto & o : options) { + if (!o.grp.empty()) { + printf("\n%s:\n\n", o.grp.c_str()); + continue; + } + printf(" %-32s", o.args.c_str()); + if (o.args.length() > 30) { + printf("\n%34s", ""); + } + + const auto desc = o.desc; + size_t start = 0; + size_t end = desc.find('\n'); + while (end != std::string::npos) { + printf("%s\n%34s", desc.substr(start, end - start).c_str(), ""); + start = end + 1; + end = desc.find('\n', start); + } + + printf("%s\n", desc.substr(start).c_str()); + } + printf("\n"); } -std::string get_system_info(const gpt_params & params) { +std::string gpt_params_get_system_info(const gpt_params & params) { std::ostringstream os; os << "system_info: n_threads = " << params.n_threads; @@ -1656,27 +1991,141 @@ std::string get_system_info(const gpt_params & params) { return os.str(); } -std::string gpt_random_prompt(std::mt19937 & rng) { - const int r = rng() % 10; - switch (r) { - case 0: return "So"; - case 1: return "Once upon a time"; - case 2: return "When"; - case 3: return "The"; - case 4: return "After"; - case 5: return "If"; - case 6: return "import"; - case 7: return "He"; - case 8: return "She"; - case 9: return "They"; +// +// String utils +// + +std::vector string_split(std::string input, char separator) { + std::vector parts; + size_t separator_pos = input.find(separator); + while (separator_pos != std::string::npos) { + std::string part = input.substr(0, separator_pos); + parts.emplace_back(part); + input = input.substr(separator_pos + 1); + separator_pos = input.find(separator); + } + parts.emplace_back(input); + return parts; +} + +std::string string_strip(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + while (start < end && std::isspace(str[start])) { + start++; + } + while (end > start && std::isspace(str[end - 1])) { + end--; + } + return str.substr(start, end - start); +} + +std::string string_get_sortable_timestamp() { + using clock = std::chrono::system_clock; + + const clock::time_point current_time = clock::now(); + const time_t as_time_t = clock::to_time_t(current_time); + char timestamp_no_ns[100]; + std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); + + const int64_t ns = std::chrono::duration_cast( + current_time.time_since_epoch() % 1000000000).count(); + char timestamp_ns[11]; + snprintf(timestamp_ns, 11, "%09" PRId64, ns); + + return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); +} + +void string_process_escapes(std::string & input) { + std::size_t input_len = input.length(); + std::size_t output_idx = 0; + + for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) { + if (input[input_idx] == '\\' && input_idx + 1 < input_len) { + switch (input[++input_idx]) { + case 'n': input[output_idx++] = '\n'; break; + case 'r': input[output_idx++] = '\r'; break; + case 't': input[output_idx++] = '\t'; break; + case '\'': input[output_idx++] = '\''; break; + case '\"': input[output_idx++] = '\"'; break; + case '\\': input[output_idx++] = '\\'; break; + case 'x': + // Handle \x12, etc + if (input_idx + 2 < input_len) { + const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 }; + char *err_p = nullptr; + const long val = std::strtol(x, &err_p, 16); + if (err_p == x + 2) { + input_idx += 2; + input[output_idx++] = char(val); + break; + } + } + // fall through + default: input[output_idx++] = '\\'; + input[output_idx++] = input[input_idx]; break; + } + } else { + input[output_idx++] = input[input_idx]; + } } - GGML_UNREACHABLE(); + input.resize(output_idx); } +bool string_parse_kv_override(const char * data, std::vector & overrides) { + const char * sep = strchr(data, '='); + if (sep == nullptr || sep - data >= 128) { + fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data); + return false; + } + llama_model_kv_override kvo; + std::strncpy(kvo.key, data, sep - data); + kvo.key[sep - data] = 0; + sep++; + if (strncmp(sep, "int:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; + kvo.val_i64 = std::atol(sep); + } else if (strncmp(sep, "float:", 6) == 0) { + sep += 6; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT; + kvo.val_f64 = std::atof(sep); + } else if (strncmp(sep, "bool:", 5) == 0) { + sep += 5; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL; + if (std::strcmp(sep, "true") == 0) { + kvo.val_bool = true; + } else if (std::strcmp(sep, "false") == 0) { + kvo.val_bool = false; + } else { + fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data); + return false; + } + } else if (strncmp(sep, "str:", 4) == 0) { + sep += 4; + kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR; + if (strlen(sep) > 127) { + fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data); + return false; + } + strncpy(kvo.val_str, sep, 127); + kvo.val_str[127] = '\0'; + } else { + fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data); + return false; + } + overrides.emplace_back(std::move(kvo)); + return true; +} + +// +// Filesystem utils +// + // Validate if a filename is safe to use // To validate a full path, split the path by the OS-specific path separator, and validate each part with this function -bool validate_file_name(const std::string & filename) { +bool fs_validate_filename(const std::string & filename) { if (!filename.length()) { // Empty filename invalid return false; @@ -1745,120 +2194,198 @@ bool validate_file_name(const std::string & filename) { return true; } -// -// String utils -// +// returns true if successful, false otherwise +bool fs_create_directory_with_parents(const std::string & path) { +#ifdef _WIN32 + std::wstring_convert> converter; + std::wstring wpath = converter.from_bytes(path); -std::vector string_split(std::string input, char separator) { - std::vector parts; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(0, separator_pos); - parts.emplace_back(part); - input = input.substr(separator_pos + 1); - separator_pos = input.find(separator); + // if the path already exists, check whether it's a directory + const DWORD attributes = GetFileAttributesW(wpath.c_str()); + if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return true; } - parts.emplace_back(input); - return parts; -} -std::string string_strip(const std::string & str) { - size_t start = 0; - size_t end = str.size(); - while (start < end && std::isspace(str[start])) { - start++; - } - while (end > start && std::isspace(str[end - 1])) { - end--; - } - return str.substr(start, end - start); -} + size_t pos_slash = 0; -std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - {"top_k", llama_sampler_type::TOP_K}, - {"top_p", llama_sampler_type::TOP_P}, - {"typical_p", llama_sampler_type::TYPICAL_P}, - {"min_p", llama_sampler_type::MIN_P}, - {"tfs_z", llama_sampler_type::TFS_Z}, - {"temperature", llama_sampler_type::TEMPERATURE} - }; + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { + const std::wstring subpath = wpath.substr(0, pos_slash); + const wchar_t * test = subpath.c_str(); - // since samplers names are written multiple ways - // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - {"top-k", llama_sampler_type::TOP_K}, - {"top-p", llama_sampler_type::TOP_P}, - {"nucleus", llama_sampler_type::TOP_P}, - {"typical-p", llama_sampler_type::TYPICAL_P}, - {"typical", llama_sampler_type::TYPICAL_P}, - {"min-p", llama_sampler_type::MIN_P}, - {"tfs-z", llama_sampler_type::TFS_Z}, - {"tfs", llama_sampler_type::TFS_Z}, - {"temp", llama_sampler_type::TEMPERATURE} - }; + const bool success = CreateDirectoryW(test, NULL); + if (!success) { + const DWORD error = GetLastError(); - std::vector sampler_types; - sampler_types.reserve(names.size()); - for (const auto & name : names) - { - auto sampler_item = sampler_canonical_name_map.find(name); - if (sampler_item != sampler_canonical_name_map.end()) - { - sampler_types.push_back(sampler_item->second); - } - else - { - if (allow_alt_names) - { - sampler_item = sampler_alt_name_map.find(name); - if (sampler_item != sampler_alt_name_map.end()) - { - sampler_types.push_back(sampler_item->second); + // if the path already exists, ensure that it's a directory + if (error == ERROR_ALREADY_EXISTS) { + const DWORD attributes = GetFileAttributesW(subpath.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { + return false; } + } else { + return false; } } + + pos_slash += 1; } - return sampler_types; -} -std::vector sampler_types_from_chars(const std::string & names_string) { - std::unordered_map sampler_name_map { - {'k', llama_sampler_type::TOP_K}, - {'p', llama_sampler_type::TOP_P}, - {'y', llama_sampler_type::TYPICAL_P}, - {'m', llama_sampler_type::MIN_P}, - {'f', llama_sampler_type::TFS_Z}, - {'t', llama_sampler_type::TEMPERATURE} - }; + return true; +#else + // if the path already exists, check whether it's a directory + struct stat info; + if (stat(path.c_str(), &info) == 0) { + return S_ISDIR(info.st_mode); + } - std::vector sampler_types; - sampler_types.reserve(names_string.size()); - for (const auto & c : names_string) { - const auto sampler_item = sampler_name_map.find(c); - if (sampler_item != sampler_name_map.end()) { - sampler_types.push_back(sampler_item->second); + size_t pos_slash = 1; // skip leading slashes for directory creation + + // process path from front to back, procedurally creating directories + while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { + const std::string subpath = path.substr(0, pos_slash); + struct stat info; + + // if the path already exists, ensure that it's a directory + if (stat(subpath.c_str(), &info) == 0) { + if (!S_ISDIR(info.st_mode)) { + return false; + } + } else { + // create parent directories + const int ret = mkdir(subpath.c_str(), 0755); + if (ret != 0) { + return false; + } } + + pos_slash += 1; } - return sampler_types; + + return true; +#endif // _WIN32 } -std::string sampler_type_to_name_string(llama_sampler_type sampler_type) { - switch (sampler_type) { - case llama_sampler_type::TOP_K: return "top_k"; - case llama_sampler_type::TFS_Z: return "tfs_z"; - case llama_sampler_type::TYPICAL_P: return "typical_p"; - case llama_sampler_type::TOP_P: return "top_p"; - case llama_sampler_type::MIN_P: return "min_p"; - case llama_sampler_type::TEMPERATURE: return "temperature"; - default : return ""; +std::string fs_get_cache_directory() { + std::string cache_directory = ""; + auto ensure_trailing_slash = [](std::string p) { + // Make sure to add trailing slash + if (p.back() != DIRECTORY_SEPARATOR) { + p += DIRECTORY_SEPARATOR; + } + return p; + }; + if (getenv("LLAMA_CACHE")) { + cache_directory = std::getenv("LLAMA_CACHE"); + } else { +#ifdef __linux__ + if (std::getenv("XDG_CACHE_HOME")) { + cache_directory = std::getenv("XDG_CACHE_HOME"); + } else { + cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } +#elif defined(__APPLE__) + cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); +#elif defined(_WIN32) + cache_directory = std::getenv("LOCALAPPDATA"); +#endif // __linux__ + cache_directory = ensure_trailing_slash(cache_directory); + cache_directory += "llama.cpp"; } + return ensure_trailing_slash(cache_directory); } + // // Model utils // +std::tuple llama_init_from_gpt_params(gpt_params & params) { + auto mparams = llama_model_params_from_gpt_params(params); + + llama_model * model = nullptr; + + if (!params.hf_repo.empty() && !params.hf_file.empty()) { + model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); + } else if (!params.model_url.empty()) { + model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); + } else { + model = llama_load_model_from_file(params.model.c_str(), mparams); + } + + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return std::make_tuple(nullptr, nullptr); + } + + auto cparams = llama_context_params_from_gpt_params(params); + + llama_context * lctx = llama_new_context_with_model(model, cparams); + if (lctx == NULL) { + fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + if (!params.control_vectors.empty()) { + if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + + const auto cvec = llama_control_vector_load(params.control_vectors); + if (cvec.n_embd == -1) { + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + + int err = llama_control_vector_apply(lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); + if (err) { + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { + const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); + float lora_scale = std::get<1>(params.lora_adapter[i]); + int err = llama_model_apply_lora_from_file(model, + lora_adapter.c_str(), + lora_scale, + ((i > 0) || params.lora_base.empty()) + ? NULL + : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + llama_free(lctx); + llama_free_model(model); + return std::make_tuple(nullptr, nullptr); + } + } + + if (params.ignore_eos) { + params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + } + + if (params.warmup) { + LOG("warming up the model with an empty run\n"); + + std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_kv_cache_clear(lctx); + llama_synchronize(lctx); + llama_reset_timings(lctx); + } + + return std::make_tuple(model, lctx); +} + struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { auto mparams = llama_model_default_params(); @@ -1944,27 +2471,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return cparams; } -void llama_batch_clear(struct llama_batch & batch) { - batch.n_tokens = 0; -} - -void llama_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - const std::vector & seq_ids, - bool logits) { - batch.token [batch.n_tokens] = id; - batch.pos [batch.n_tokens] = pos; - batch.n_seq_id[batch.n_tokens] = seq_ids.size(); - for (size_t i = 0; i < seq_ids.size(); ++i) { - batch.seq_id[batch.n_tokens][i] = seq_ids[i]; - } - batch.logits [batch.n_tokens] = logits; - - batch.n_tokens++; -} - #ifdef LLAMA_USE_CURL static bool starts_with(const std::string & str, const std::string & prefix) { @@ -2295,90 +2801,29 @@ struct llama_model * llama_load_model_from_hf( #endif // LLAMA_USE_CURL -std::tuple llama_init_from_gpt_params(gpt_params & params) { - auto mparams = llama_model_params_from_gpt_params(params); +// +// Batch utils +// - llama_model * model = nullptr; +void llama_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} - if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), mparams); - } else if (!params.model_url.empty()) { - model = llama_load_model_from_url(params.model_url.c_str(), params.model.c_str(), mparams); - } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits) { + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } + batch.logits [batch.n_tokens] = logits; - if (model == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return std::make_tuple(nullptr, nullptr); - } - - auto cparams = llama_context_params_from_gpt_params(params); - - llama_context * lctx = llama_new_context_with_model(model, cparams); - if (lctx == NULL) { - fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - - if (!params.control_vectors.empty()) { - if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); - - const auto cvec = llama_control_vector_load(params.control_vectors); - if (cvec.n_embd == -1) { - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - - int err = llama_control_vector_apply(lctx, - cvec.data.data(), - cvec.data.size(), - cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); - if (err) { - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - } - - for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { - const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); - float lora_scale = std::get<1>(params.lora_adapter[i]); - int err = llama_model_apply_lora_from_file(model, - lora_adapter.c_str(), - lora_scale, - ((i > 0) || params.lora_base.empty()) - ? NULL - : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { - fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - llama_free(lctx); - llama_free_model(model); - return std::make_tuple(nullptr, nullptr); - } - } - - if (params.ignore_eos) { - params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; - } - - if (params.warmup) { - LOG("warming up the model with an empty run\n"); - - std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_clear(lctx); - llama_synchronize(lctx); - llama_reset_timings(lctx); - } - - return std::make_tuple(model, lctx); + batch.n_tokens++; } // @@ -2466,321 +2911,17 @@ bool llama_should_add_bos_token(const llama_model * model) { return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); } -// -// YAML utils -// - -// returns true if successful, false otherwise -bool create_directory_with_parents(const std::string & path) { -#ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); - - // if the path already exists, check whether it's a directory - const DWORD attributes = GetFileAttributesW(wpath.c_str()); - if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) { - return true; - } - - size_t pos_slash = 0; - - // process path from front to back, procedurally creating directories - while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { - const std::wstring subpath = wpath.substr(0, pos_slash); - const wchar_t * test = subpath.c_str(); - - const bool success = CreateDirectoryW(test, NULL); - if (!success) { - const DWORD error = GetLastError(); - - // if the path already exists, ensure that it's a directory - if (error == ERROR_ALREADY_EXISTS) { - const DWORD attributes = GetFileAttributesW(subpath.c_str()); - if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) { - return false; - } - } else { - return false; - } - } - - pos_slash += 1; - } - - return true; -#else - // if the path already exists, check whether it's a directory - struct stat info; - if (stat(path.c_str(), &info) == 0) { - return S_ISDIR(info.st_mode); - } - - size_t pos_slash = 1; // skip leading slashes for directory creation - - // process path from front to back, procedurally creating directories - while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) { - const std::string subpath = path.substr(0, pos_slash); - struct stat info; - - // if the path already exists, ensure that it's a directory - if (stat(subpath.c_str(), &info) == 0) { - if (!S_ISDIR(info.st_mode)) { - return false; - } - } else { - // create parent directories - const int ret = mkdir(subpath.c_str(), 0755); - if (ret != 0) { - return false; - } - } - - pos_slash += 1; - } - - return true; -#endif // _WIN32 -} - -void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%e, ", data[i]); - } - fprintf(stream, "%e]\n", data.back()); -} - -void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%d, ", data[i]); - } - fprintf(stream, "%d]\n", data.back()); -} - -void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) { - std::string data_str(data == NULL ? "" : data); - - if (data_str.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - size_t pos_start = 0; - size_t pos_found = 0; - - if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { - data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); - data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); - data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); - data_str = "\"" + data_str + "\""; - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - if (data_str.find('\n') == std::string::npos) { - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - fprintf(stream, "%s: |\n", prop_name); - while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { - fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); - pos_start = pos_found + 1; - } -} - -std::string get_sortable_timestamp() { - using clock = std::chrono::system_clock; - - const clock::time_point current_time = clock::now(); - const time_t as_time_t = clock::to_time_t(current_time); - char timestamp_no_ns[100]; - std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t)); - - const int64_t ns = std::chrono::duration_cast( - current_time.time_since_epoch() % 1000000000).count(); - char timestamp_ns[11]; - snprintf(timestamp_ns, 11, "%09" PRId64, ns); - - return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns); -} - -void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sparams; - - fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); - fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx_vnni: %s\n", ggml_cpu_has_avx_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_cuda: %s\n", ggml_cpu_has_cuda() ? "true" : "false"); - fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false"); - fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false"); - fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); - fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false"); - -#ifdef NDEBUG - fprintf(stream, "debug: false\n"); -#else - fprintf(stream, "debug: true\n"); -#endif // NDEBUG - - fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); - -#ifdef __OPTIMIZE__ - fprintf(stream, "optimize: true\n"); -#else - fprintf(stream, "optimize: false\n"); -#endif // __OPTIMIZE__ - - fprintf(stream, "time: %s\n", timestamp.c_str()); - - fprintf(stream, "\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "# User Inputs #\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "\n"); - - fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); - fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); - fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); - fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); - fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); - fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); - dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str()); - fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); - fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); - fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; - fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); - - dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str()); - fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); - dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str()); - fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false"); - fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); - fprintf(stream, "interactive_specials: %s # default: false\n", params.interactive_specials ? "true" : "false"); - fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); - fprintf(stream, "keep: %d # default: 0\n", params.n_keep); - fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); - - fprintf(stream, "logit_bias:\n"); - for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { - continue; - } - fprintf(stream, " %d: %f", lb.first, lb.second); - } - - fprintf(stream, "lora:\n"); - for (std::tuple la : params.lora_adapter) { - if (std::get<1>(la) != 1.0f) { - continue; - } - fprintf(stream, " - %s\n", std::get<0>(la).c_str()); - } - fprintf(stream, "lora_scaled:\n"); - for (std::tuple la : params.lora_adapter) { - if (std::get<1>(la) == 1.0f) { - continue; - } - fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); - } - fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); - fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); - fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); - fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); - fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); - fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); - fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); - fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); - fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); - dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); - fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); - fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); - fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); - dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); - fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); - - fprintf(stream, "reverse_prompt:\n"); - for (std::string ap : params.antiprompt) { - size_t pos = 0; - while ((pos = ap.find('\n', pos)) != std::string::npos) { - ap.replace(pos, 1, "\\n"); - pos += 1; - } - - fprintf(stream, " - %s\n", ap.c_str()); - } - - fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); - fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); - fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); - - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); - dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); - - fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); - fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); - fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); +bool llama_chat_verify_template(const std::string & tmpl) { + llama_chat_message chat[] = {{"user", "test"}}; + int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); + return res >= 0; } // // KV cache utils // -void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { +void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", @@ -2803,7 +2944,7 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { +void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", @@ -2851,6 +2992,10 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } +// +// Embedding utils +// + void llama_embd_normalize(const float * inp, float * out, int n) { double sum = 0.0; for (int i = 0; i < n; i++) { @@ -3035,3 +3180,222 @@ llama_control_vector_data llama_control_vector_load(const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%e, ", data[i]); + } + fprintf(stream, "%e]\n", data.back()); +} + +void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector & data) { + if (data.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + fprintf(stream, "%s: [", prop_name); + for (size_t i = 0; i < data.size() - 1; ++i) { + fprintf(stream, "%d, ", data[i]); + } + fprintf(stream, "%d]\n", data.back()); +} + +void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) { + std::string data_str(data == NULL ? "" : data); + + if (data_str.empty()) { + fprintf(stream, "%s:\n", prop_name); + return; + } + + size_t pos_start = 0; + size_t pos_found = 0; + + if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { + data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); + data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); + data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); + data_str = "\"" + data_str + "\""; + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + if (data_str.find('\n') == std::string::npos) { + fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); + return; + } + + fprintf(stream, "%s: |\n", prop_name); + while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { + fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); + pos_start = pos_found + 1; + } +} + +void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { + const llama_sampling_params & sparams = params.sparams; + + fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); + fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); + fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); + fprintf(stream, "cpu_has_avx_vnni: %s\n", ggml_cpu_has_avx_vnni() ? "true" : "false"); + fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); + fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); + fprintf(stream, "cpu_has_cuda: %s\n", ggml_cpu_has_cuda() ? "true" : "false"); + fprintf(stream, "cpu_has_vulkan: %s\n", ggml_cpu_has_vulkan() ? "true" : "false"); + fprintf(stream, "cpu_has_kompute: %s\n", ggml_cpu_has_kompute() ? "true" : "false"); + fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); + fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false"); + fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); + fprintf(stream, "cpu_has_sve: %s\n", ggml_cpu_has_sve() ? "true" : "false"); + fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); + fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); + fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); + fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false"); + fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); + fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); + fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false"); + +#ifdef NDEBUG + fprintf(stream, "debug: false\n"); +#else + fprintf(stream, "debug: true\n"); +#endif // NDEBUG + + fprintf(stream, "model_desc: %s\n", model_desc); + fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); + +#ifdef __OPTIMIZE__ + fprintf(stream, "optimize: true\n"); +#else + fprintf(stream, "optimize: false\n"); +#endif // __OPTIMIZE__ + + fprintf(stream, "time: %s\n", timestamp.c_str()); + + fprintf(stream, "\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "# User Inputs #\n"); + fprintf(stream, "###############\n"); + fprintf(stream, "\n"); + + fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); + fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); + yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); + fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); + fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); + fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); + fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); + fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); + fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); + yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str()); + fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); + fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); + fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); + + const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); + const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; + fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + + yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); + fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); + yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str()); + fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); + fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); + fprintf(stream, "keep: %d # default: 0\n", params.n_keep); + fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); + + fprintf(stream, "logit_bias:\n"); + for (std::pair lb : sparams.logit_bias) { + if (ignore_eos && lb.first == logit_bias_eos->first) { + continue; + } + fprintf(stream, " %d: %f", lb.first, lb.second); + } + + fprintf(stream, "lora:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) != 1.0f) { + continue; + } + fprintf(stream, " - %s\n", std::get<0>(la).c_str()); + } + fprintf(stream, "lora_scaled:\n"); + for (std::tuple la : params.lora_adapter) { + if (std::get<1>(la) == 1.0f) { + continue; + } + fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); + } + fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); + fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); + fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); + fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); + fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); + fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); + fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); + fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); + fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); + fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); + fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); + fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); + fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); + fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); + fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); + fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); + fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); + fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); + yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str()); + fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); + fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); + fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); + yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens); + fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); + + fprintf(stream, "reverse_prompt:\n"); + for (std::string ap : params.antiprompt) { + size_t pos = 0; + while ((pos = ap.find('\n', pos)) != std::string::npos) { + ap.replace(pos, 1, "\\n"); + pos += 1; + } + + fprintf(stream, " - %s\n", ap.c_str()); + } + + fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); + fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); + fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); + fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); + fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); + fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); + fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); + + const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); + yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); + + fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); + fprintf(stream, "threads: %d # default: %u\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); + fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); + fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); + fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); + fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); +} diff --git a/llama/common.h b/llama/common.h index 75e926d3..913d8c2b 100644 --- a/llama/common.h +++ b/llama/common.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -53,7 +53,7 @@ #define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0) #define print_build_info() do { \ - fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ } while(0) @@ -61,14 +61,18 @@ // build info extern int LLAMA_BUILD_NUMBER; -extern char const *LLAMA_COMMIT; -extern char const *LLAMA_COMPILER; -extern char const *LLAMA_BUILD_TARGET; +extern char const * LLAMA_COMMIT; +extern char const * LLAMA_COMPILER; +extern char const * LLAMA_BUILD_TARGET; struct llama_control_vector_load_info; -int get_math_cpu_count(); -int32_t get_num_physical_cores(); +// +// CPU utils +// + +int32_t cpu_get_num_physical_cores(); +int32_t cpu_get_num_math(); // // CLI argument parsing @@ -77,67 +81,68 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_threads = get_math_cpu_count(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_threads_batch_draft = -1; - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 512; // context size - int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 5; // number of tokens to draft during speculative decoding - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. - int32_t grp_attn_n = 1; // group-attention factor - int32_t grp_attn_w = 512; // group-attention width - int32_t n_print = -1; // print token count every n tokens (-1 = disabled) - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + int32_t n_threads = cpu_get_num_math(); + int32_t n_threads_draft = -1; + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch_draft = -1; + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 0; // context size + int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 5; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t n_beams = 0; // if non-zero then use beam search of given width. + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim - int32_t yarn_orig_ctx = 0; // YaRN original context length + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length float defrag_thold = -1.0f; // KV cache defragmentation threshold - std::string rpc_servers = ""; // comma separated list of RPC servers ggml_backend_sched_eval_callback cb_eval = nullptr; void * cb_eval_user_data = nullptr; ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; + enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings // // sampling parameters struct llama_sampling_params sparams; - std::string model = ""; // model path - std::string model_draft = ""; // draft model for speculative decoding + std::string model = ""; // model path + std::string model_draft = ""; // draft model for speculative decoding std::string model_alias = "unknown"; // model alias - std::string model_url = ""; // model url to download - std::string hf_repo = ""; // HF repo - std::string hf_file = ""; // HF file + std::string model_url = ""; // model url to download + std::string hf_repo = ""; // HF repo + std::string hf_file = ""; // HF file std::string prompt = ""; - std::string prompt_file = ""; // store the external prompt file name - std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::vector antiprompt; // string upon seeing which more user input is prompted - std::string logdir = ""; // directory in which to save YAML log files + std::string prompt_file = ""; // store the external prompt file name + std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state + std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with + std::string logdir = ""; // directory in which to save YAML log files std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding - std::string logits_file = ""; // file for saving *all* logits + std::string logits_file = ""; // file for saving *all* logits + std::string rpc_servers = ""; // comma separated list of RPC servers + std::vector in_files; // all input files + std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; // TODO: avoid tuple, use struct @@ -146,36 +151,36 @@ struct gpt_params { std::vector control_vectors; // control vector with user defined scale + int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector - int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. - int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line - // (which is more convenient to use for plotting) - // - bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt - size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line + // (which is more convenient to use for plotting) + // + bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt + size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score - bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt - size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt + size_t winogrande_tasks = 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed - bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt - size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed + bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt + size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed - bool kl_divergence = false; // compute KL divergence + bool kl_divergence = false; // compute KL divergence - bool random_prompt = false; // do not randomize prompt if none provided + bool usage = false; // print usage bool use_color = false; // use color to distinguish generations and inputs + bool special = false; // enable special token output bool interactive = false; // interactive mode - bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode + bool interactive_first = false; // wait for user input immediately bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) - bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it bool embedding = false; // get only sentence embedding - bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\" - bool interactive_first = false; // wait for user input immediately + bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\" bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly @@ -183,7 +188,6 @@ struct gpt_params { bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens - bool instruct = false; // instruction mode (used for Alpaca models) bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory @@ -201,37 +205,102 @@ struct gpt_params { // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector std::vector image; // path to image file(s) + + // server params + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests + + std::string hostname = "127.0.0.1"; + std::string public_path = ""; + std::string chat_template = ""; + std::string system_prompt = ""; + + std::vector api_keys; + + std::string ssl_file_key = ""; + std::string ssl_file_cert = ""; + + bool endpoint_slots = true; + bool endpoint_metrics = false; + + bool log_json = false; + + std::string slot_save_path; + + // batched-bench params + bool is_pp_shared = false; + + std::vector n_pp; + std::vector n_tg; + std::vector n_pl; + + // retrieval params + std::vector context_files; // context files to embed + + int32_t chunk_size = 64; // chunk size for context embedding + + std::string chunk_separator = "\n"; // chunk separator for context embedding + + // passkey params + int32_t n_junk = 250; // number of times to repeat the junk text + int32_t i_pos = -1; // position of the passkey in the junk text + + // imatrix params + std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file + + int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations + int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations + int32_t i_chunk = 0; // start processing from this chunk + + bool process_output = false; // collect data for the output tensor + bool compute_ppl = true; // whether to compute perplexity }; void gpt_params_handle_model_default(gpt_params & params); -bool parse_kv_override(const char * data, std::vector & overrides); +bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params); +bool gpt_params_parse (int argc, char ** argv, gpt_params & params); +bool gpt_params_find_arg (int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param); +void gpt_params_print_usage(int argc, char ** argv, const gpt_params & params); -bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params); - -bool gpt_params_parse(int argc, char ** argv, gpt_params & params); - -void gpt_print_usage(int argc, char ** argv, const gpt_params & params); - -bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param); - -std::string get_system_info(const gpt_params & params); - -std::string gpt_random_prompt(std::mt19937 & rng); - -void process_escapes(std::string& input); - -bool validate_file_name(const std::string & filename); +std::string gpt_params_get_system_info(const gpt_params & params); // // String utils // -std::vector sampler_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector sampler_types_from_chars(const std::string & names_string); std::vector string_split(std::string input, char separator); + std::string string_strip(const std::string & str); -std::string sampler_type_to_name_string(llama_sampler_type sampler_type); +std::string string_get_sortable_timestamp(); + +template +static std::vector string_split(const std::string & str, char delim) { + std::vector values; + std::istringstream str_stream(str); + std::string token; + while (std::getline(str_stream, token, delim)) { + T value; + std::istringstream token_stream(token); + token_stream >> value; + values.push_back(value); + } + return values; +} + +bool string_parse_kv_override(const char * data, std::vector & overrides); +void string_process_escapes(std::string & input); + +// +// Filesystem utils +// + +bool fs_validate_filename(const std::string & filename); +bool fs_create_directory_with_parents(const std::string & path); + +std::string fs_get_cache_directory(); // // Model utils @@ -303,28 +372,21 @@ std::string llama_detokenize_bpe( bool llama_should_add_bos_token(const llama_model * model); // -// YAML utils +// Chat template utils // -bool create_directory_with_parents(const std::string & path); -void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data); -void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data); -void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data); -std::string get_sortable_timestamp(); - -void dump_non_result_info_yaml( - FILE * stream, const gpt_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); +// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid +bool llama_chat_verify_template(const std::string & tmpl); // // KV cache utils // // Dump the KV cache view with the number of sequences per cell. -void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); +void llama_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); +void llama_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); // // Embedding utils @@ -358,6 +420,20 @@ llama_control_vector_data llama_control_vector_load(const std::vector & data); +void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std::vector & data); +void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); + +void yaml_dump_non_result_info( + FILE * stream, const gpt_params & params, const llama_context * lctx, + const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); + diff --git a/llama/ggml-alloc.c b/llama/ggml-alloc.c index 5bb17eaa..8d296910 100644 --- a/llama/ggml-alloc.c +++ b/llama/ggml-alloc.c @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -403,7 +403,7 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->bufts = calloc(n_bufs, sizeof(ggml_backend_buffer_type_t)); GGML_ASSERT(galloc->bufts != NULL); - galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t) * n_bufs); + galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t)); GGML_ASSERT(galloc->buffers != NULL); galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); @@ -776,7 +776,7 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * // this tensor was allocated without ggml-backend return; } - ggml_backend_view_init(galloc->buffers[buffer_id], tensor); + ggml_backend_view_init(tensor); } } else { if (tensor->data == NULL) { @@ -925,12 +925,12 @@ static bool alloc_tensor_range(struct ggml_context * ctx, if (t->view_src == NULL) { ggml_tallocr_alloc(&tallocr, t); } else if (t->buffer == NULL) { - ggml_backend_view_init(buffer, t); + ggml_backend_view_init(t); } } else { if (t->view_src != NULL && t->buffer == NULL) { // view of a pre-allocated tensor - ggml_backend_view_init(buffer, t); + ggml_backend_view_init(t); } } } diff --git a/llama/ggml-alloc.h b/llama/ggml-alloc.h index 05608614..b3647884 100644 --- a/llama/ggml-alloc.h +++ b/llama/ggml-alloc.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/ggml-backend-impl.h b/llama/ggml-backend-impl.h index a0f3581c..91470140 100644 --- a/llama/ggml-backend-impl.h +++ b/llama/ggml-backend-impl.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/ggml-backend.c b/llama/ggml-backend.c index ce6e9b23..9225f7ab 100644 --- a/llama/ggml-backend.c +++ b/llama/ggml-backend.c @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -182,7 +182,7 @@ void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst) { ggml_backend_buffer_t dst_buf = dst->view_src ? dst->view_src->buffer : dst->buffer; if (dst_buf->iface.cpy_tensor) { - return src->buffer->iface.cpy_tensor(dst_buf, src, dst); + return dst_buf->iface.cpy_tensor(dst_buf, src, dst); } return false; } @@ -1918,15 +1918,15 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, // utils -void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +void ggml_backend_view_init(struct ggml_tensor * tensor) { GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->view_src != NULL); GGML_ASSERT(tensor->view_src->buffer != NULL); GGML_ASSERT(tensor->view_src->data != NULL); - tensor->buffer = buffer; + tensor->buffer = tensor->view_src->buffer; tensor->data = (char *)tensor->view_src->data + tensor->view_offs; - ggml_backend_buffer_init_tensor(buffer, tensor); + ggml_backend_buffer_init_tensor(tensor->buffer, tensor); } void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { @@ -1985,7 +1985,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te struct ggml_tensor * dst = node_copies[id]; if (dst->view_src != NULL) { graph_copy_init_tensor(hash_set, node_copies, node_init, src->view_src); - ggml_backend_view_init(dst->view_src->buffer, dst); + ggml_backend_view_init(dst); } else { ggml_backend_tensor_copy(src, dst); diff --git a/llama/ggml-backend.h b/llama/ggml-backend.h index 31768e67..1041d35e 100644 --- a/llama/ggml-backend.h +++ b/llama/ggml-backend.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -251,7 +251,7 @@ extern "C" { // Tensor initialization GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); - GGML_API void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); #ifdef __cplusplus diff --git a/llama/ggml-common.h b/llama/ggml-common.h index 21939860..dbbaa5d1 100644 --- a/llama/ggml-common.h +++ b/llama/ggml-common.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -91,13 +91,8 @@ typedef sycl::half2 ggml_half2; // QK = number of values after dequantization // QK_K = super-block size -#ifdef GGML_QKK_64 -#define QK_K 64 -#define K_SCALE_SIZE 4 -#else #define QK_K 256 #define K_SCALE_SIZE 12 -#endif // GGML_QKK_64 #if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL) // QR = QK / number of values before dequantization @@ -154,16 +149,17 @@ typedef sycl::half2 ggml_half2; #define QI1_S (QK_K / (4*QR1_S)) #define QR1_S 8 +#define QI1_M (QK_K / (4*QR1_M)) +#define QR1_M 8 + #define QI4_NL (QK4_NL / (4*QR4_NL)) #define QR4_NL 2 -#if QK_K == 64 -#define QI4_XS QI4_NL -#define QR4_XS QR4_NL -#else #define QI4_XS (QK_K / (4*QR4_XS)) #define QR4_XS 8 -#endif + +#define QI3_S (QK_K / (4*QR3_S)) +#define QR3_S 8 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP @@ -254,15 +250,6 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wro // weight is represented as x = a * q // 16 blocks of 16 elements each // Effectively 3.4375 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - uint8_t hmask[QK_K/8]; // quants - high bit - uint8_t qs[QK_K/4]; // quants - low 2 bits - uint8_t scales[2]; - ggml_half d; // super-block scale -} block_q3_K; -static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); -#else typedef struct { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits @@ -270,20 +257,11 @@ typedef struct { ggml_half d; // super-block scale } block_q3_K; static_assert(sizeof(block_q3_K) == sizeof(ggml_half) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); -#endif // 4-bit quantization // 8 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 4.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_half d[2]; // super-block scales/mins - uint8_t scales[2]; // 4-bit block scales/mins - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_K; -static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + QK_K/2 + 2, "wrong q4_K block size/padding"); -#else typedef struct { union { struct { @@ -296,21 +274,11 @@ typedef struct { uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); -#endif // 5-bit quantization // 8 blocks of 32 elements each // weight is represented as x = a * q + b // Effectively 5.5 bits per weight -#ifdef GGML_QKK_64 -typedef struct { - ggml_half d; // super-block scale - int8_t scales[QK_K/16]; // 8-bit block scales - uint8_t qh[QK_K/8]; // quants, high bit - uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_K; -static_assert(sizeof(block_q5_K) == sizeof(ggml_half) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); -#else typedef struct { union { struct { @@ -324,7 +292,6 @@ typedef struct { uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_half) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); -#endif // 6-bit quantization // weight is represented as x = a * q @@ -382,11 +349,7 @@ typedef struct { static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_half) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); // 3.4375 bpw -#if QK_K == 64 -#define IQ3S_N_SCALE 2 -#else #define IQ3S_N_SCALE QK_K/64 -#endif typedef struct { ggml_half d; uint8_t qs[QK_K/4]; @@ -407,16 +370,9 @@ static_assert(sizeof(block_iq1_s) == sizeof(ggml_half) + QK_K/8 + QK_K/16, "wron typedef struct { uint8_t qs[QK_K/8]; // grid index, low 8 bits uint8_t qh[QK_K/16]; // grid index, high 3 bits + grid shift bit (for two groups of 8) -#if QK_K == 64 - ggml_half d; -#endif uint8_t scales[QK_K/32]; // 3-bit block scales (4-bit if QK_K == 64) } block_iq1_m; -#if QK_K == 64 -static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32 + sizeof(ggml_half), "wrong iq1_m block size/padding"); -#else static_assert(sizeof(block_iq1_m) == QK_K/8 + QK_K/16 + QK_K/32, "wrong iq1_m block size/padding"); -#endif // Used by IQ1_M quants typedef union { @@ -432,9 +388,6 @@ typedef struct { } block_iq4_nl; static_assert(sizeof(block_iq4_nl) == sizeof(ggml_half) + QK4_NL/2, "wrong iq4_nl block size/padding"); -#if QK_K == 64 -#define block_iq4_xs block_iq4_nl -#else typedef struct { ggml_half d; uint16_t scales_h; @@ -442,7 +395,6 @@ typedef struct { uint8_t qs[QK_K/2]; } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); -#endif #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/llama/ggml-cuda.cu b/llama/ggml-cuda.cu index ff545ca7..a7b17d44 100644 --- a/llama/ggml-cuda.cu +++ b/llama/ggml-cuda.cu @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -69,19 +69,59 @@ #include #include #include +#include +#include #include #include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +static void ggml_cuda_default_log_callback(enum ggml_log_level level, const char * msg, void * user_data) { + GGML_UNUSED(level); + GGML_UNUSED(user_data); + fprintf(stderr, "%s", msg); +} + +ggml_log_callback ggml_cuda_log_callback = ggml_cuda_default_log_callback; +void * ggml_cuda_log_user_data = NULL; + +GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data) { + ggml_cuda_log_callback = log_callback; + ggml_cuda_log_user_data = user_data; +} + +#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__) +#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) { + if (ggml_cuda_log_callback != NULL) { + va_list args; + va_start(args, format); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + ggml_cuda_log_callback(level, buffer, ggml_cuda_log_user_data); + } else { + std::vector buffer2(len + 1); // vsnprintf adds a null terminator + va_end(args); + va_start(args, format); + vsnprintf(&buffer2[0], buffer2.size(), format, args); + ggml_cuda_log_callback(level, buffer2.data(), ggml_cuda_log_user_data); + } + va_end(args); + } +} + [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails cudaGetDevice(&id); - fprintf(stderr, "CUDA error: %s\n", msg); - fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line); - fprintf(stderr, " %s\n", stmt); + GGML_CUDA_LOG_ERROR("CUDA error: %s\n", msg); + GGML_CUDA_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); + GGML_CUDA_LOG_ERROR(" %s\n", stmt); // abort with GGML_ASSERT to get a stack trace GGML_ASSERT(!"CUDA error"); } @@ -105,6 +145,20 @@ int ggml_cuda_get_device() { return id; } +static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { + ggml_cuda_set_device(device); +#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA) + auto res = hipMallocManaged(ptr, size); + if (res == hipSuccess) { + // if error we "need" to know why... + CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + } + return res; +#else + return cudaMalloc(ptr, size); +#endif +} + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: @@ -117,7 +171,7 @@ static ggml_cuda_device_info ggml_cuda_init() { cudaError_t err = cudaGetDeviceCount(&info.device_count); if (err != cudaSuccess) { - fprintf(stderr, "%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err)); + GGML_CUDA_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err)); return info; } @@ -125,16 +179,16 @@ static ggml_cuda_device_info ggml_cuda_init() { int64_t total_vram = 0; #if defined(GGML_CUDA_FORCE_MMQ) - fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); #else - fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); + GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); #endif #if defined(CUDA_USE_TENSOR_CORES) - fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); + GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__); #else - fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__); + GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__); #endif - fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -155,7 +209,7 @@ static ggml_cuda_device_info ggml_cuda_init() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); + GGML_CUDA_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; @@ -257,12 +311,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC - fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz, - (uint32_t)(max_size/1024/1024), (uint32_t)(pool_size/1024/1024), (uint32_t)(size/1024/1024)); + GGML_CUDA_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz, + (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024)); #endif return ptr; } @@ -276,7 +330,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { return; } } - fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + GGML_CUDA_LOG_WARN("Cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); ggml_cuda_set_device(device); CUDA_CHECK(cudaFree(ptr)); pool_size -= size; @@ -527,9 +581,11 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 void * dev_ptr; - cudaError_t err = cudaMalloc(&dev_ptr, size); + cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err != cudaSuccess) { - fprintf(stderr, "%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size/1024.0/1024.0, buft_ctx->device, cudaGetErrorString(err)); + // clear the error + cudaGetLastError(); + GGML_CUDA_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); return nullptr; } @@ -607,88 +663,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { // cuda split buffer -static int64_t get_row_rounding(ggml_type type, const std::array & tensor_split) { - int64_t min_compute_capability = INT_MAX; - int64_t max_compute_capability = INT_MIN; +static int64_t get_row_rounding(const std::array & tensor_split) { + int64_t row_rounding = 0; for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { - if (min_compute_capability > ggml_cuda_info().devices[id].cc) { - min_compute_capability = ggml_cuda_info().devices[id].cc; - } - if (max_compute_capability < ggml_cuda_info().devices[id].cc) { - max_compute_capability = ggml_cuda_info().devices[id].cc; - } + if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { + continue; } - } -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return max_compute_capability >= CC_RDNA2 ? 128 : 64; - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return 1; - case GGML_TYPE_Q2_K: - return max_compute_capability >= CC_RDNA2 ? 128 : 32; - case GGML_TYPE_Q3_K: - return min_compute_capability < CC_RDNA2 ? 128 : 64; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - return max_compute_capability >= CC_RDNA2 ? 128 : 64; - default: - GGML_ASSERT(false); + const int cc = ggml_cuda_info().devices[id].cc; + row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc))); } -#else - switch(type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return max_compute_capability >= CC_VOLTA ? 128 : 64; - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - return 64; - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return 1; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ3_S: - return max_compute_capability >= CC_VOLTA ? 128 : 64; - case GGML_TYPE_Q6_K: - return 64; - default: - GGML_ASSERT(false); - } -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return row_rounding; } static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array & tensor_split, int id) { const int64_t nrows = ggml_nrows(tensor); - const int64_t rounding = get_row_rounding(tensor->type, tensor_split); + const int64_t rounding = get_row_rounding(tensor_split); *row_low = id == 0 ? 0 : nrows*tensor_split[id]; *row_low -= *row_low % rounding; @@ -786,7 +776,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_cuda_set_device(id); char * buf; - CUDA_CHECK(cudaMalloc(&buf, size)); + CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id)); // set padding to 0 to avoid possible NaN values if (size > original_size) { @@ -1032,8 +1022,8 @@ static void * ggml_cuda_host_malloc(size_t size) { if (err != cudaSuccess) { // clear the error cudaGetLastError(); - fprintf(stderr, "%s: warning: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, - size/1024.0/1024.0, cudaGetErrorString(err)); + GGML_CUDA_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, + size / 1024.0 / 1024.0, cudaGetErrorString(err)); return nullptr; } @@ -1473,7 +1463,7 @@ static void ggml_cuda_op_mul_mat( // for multi GPU, get the row boundaries from tensor split // and round to mul_mat_q tile sizes if (split) { - const int64_t rounding = get_row_rounding(src0->type, tensor_split); + const int64_t rounding = get_row_rounding(tensor_split); if (id != 0) { dev[id].row_low = ne01*tensor_split[id]; @@ -1844,7 +1834,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else - if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( @@ -2276,7 +2266,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { - fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); + GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]); return false; } else { ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst); @@ -2330,7 +2320,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { - fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst)); + GGML_CUDA_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst)); CUDA_CHECK(err); } @@ -2498,15 +2488,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool use_cuda_graph = true; bool cuda_graph_update_required = false; - // pointer to CUDA cpy kernel, which is required to identify + // vector of pointers to CUDA cpy kernels, which are required to identify // kernel parameters which need updated in the graph for each token - void * ggml_cuda_cpy_fn_ptr = nullptr; + std::vector ggml_cuda_cpy_fn_ptrs; if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) { cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; #ifndef NDEBUG - fprintf(stderr, "%s: disabling CUDA graphs due to GPU architecture\n", __func__); + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__); #endif } } @@ -2553,14 +2543,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (node->src[0] && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) { use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture #ifndef NDEBUG - fprintf(stderr, "%s: disabling CUDA graphs due to split buffer\n", __func__); + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__); #endif } if (node->op == GGML_OP_MUL_MAT_ID) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG - fprintf(stderr, "%s: disabling CUDA graphs due to mul_mat_id\n", __func__); + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); #endif } @@ -2569,16 +2559,17 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t // Changes in batch size or context size can cause changes to the grid size of some kernels. use_cuda_graph = false; #ifndef NDEBUG - fprintf(stderr, "%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); #endif } if (node->op == GGML_OP_CPY) { // store the copy op parameter which changes with each token. cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data)); - if (ggml_cuda_cpy_fn_ptr == nullptr) { - // store a pointer to the copy op CUDA kernel to identify it later - ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + // store a pointer to each copy op CUDA kernel to identify it later + void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]); + if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) { + ggml_cuda_cpy_fn_ptrs.push_back(ptr); } } @@ -2597,7 +2588,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; #ifndef NDEBUG - fprintf(stderr, "%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); #endif } } @@ -2635,7 +2626,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { - fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); } @@ -2654,7 +2645,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t use_cuda_graph = false; cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true; #ifndef NDEBUG - fprintf(stderr, "%s: disabling CUDA graphs due to failed graph capture\n", __func__); + GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__); #endif } else { graph_evaluated_or_captured = true; // CUDA graph has been captured @@ -2675,10 +2666,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (cuda_graph_update_required) { // Extract nodes from graph - if (cuda_ctx->cuda_graph->num_nodes == 0) { - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - } + // First call with null argument gets number of nodes in graph + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); // Subsequent call with non-null argument gets nodes cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); @@ -2708,7 +2697,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured int k = 0; for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) { + if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); @@ -2721,7 +2710,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); if (stat == cudaErrorGraphExecUpdateFailure) { #ifndef NDEBUG - fprintf(stderr, "%s: CUDA graph update failed\n", __func__); + GGML_CUDA_LOG_ERROR("%s: CUDA graph update failed\n", __func__); #endif // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate @@ -2859,7 +2848,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_CONT: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: + return true; case GGML_OP_ROPE: + return ggml_is_contiguous(op->src[0]); case GGML_OP_IM2COL: case GGML_OP_POOL_2D: case GGML_OP_SUM_ROWS: @@ -2876,10 +2867,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128; #else - if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) { + if (op->src[0]->ne[0] == 128) { return true; } - return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA; + if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { + return true; + } + return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && + op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; @@ -2978,13 +2973,13 @@ static ggml_guid_t ggml_backend_cuda_guid() { GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) { if (device < 0 || device >= ggml_backend_cuda_get_device_count()) { - fprintf(stderr, "%s: error: invalid device %d\n", __func__, device); + GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device); return nullptr; } ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device); if (ctx == nullptr) { - fprintf(stderr, "%s: error: failed to allocate context\n", __func__); + GGML_CUDA_LOG_ERROR("%s: failed to allocate context\n", __func__); return nullptr; } @@ -3028,8 +3023,8 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size // clear the error cudaGetLastError(); - fprintf(stderr, "%s: warning: failed to register %.2f MiB of pinned memory: %s\n", __func__, - size/1024.0/1024.0, cudaGetErrorString(err)); + GGML_CUDA_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, + size / 1024.0 / 1024.0, cudaGetErrorString(err)); return false; } return true; diff --git a/llama/ggml-cuda.h b/llama/ggml-cuda.h index 3983f201..821853e8 100644 --- a/llama/ggml-cuda.h +++ b/llama/ggml-cuda.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -66,6 +66,7 @@ GGML_API GGML_CALL void ggml_backend_cuda_get_device_memory(int device, size_t * GGML_API GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); GGML_API GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer); +GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback, void * user_data); #ifdef __cplusplus } #endif diff --git a/llama/ggml-cuda/common.cuh b/llama/ggml-cuda/common.cuh index 8f6fd71c..90a0a81e 100644 --- a/llama/ggml-cuda/common.cuh +++ b/llama/ggml-cuda/common.cuh @@ -79,13 +79,8 @@ #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister #define cudaLaunchHostFunc hipLaunchHostFunc -#ifdef GGML_HIP_UMA -#define cudaMalloc hipMallocManaged -#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size) -#else #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) -#endif #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync @@ -165,7 +160,7 @@ #endif #define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels -#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available +#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -489,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope( return powf(base, exph); } +template +struct ggml_cuda_type_traits; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = 1; + static constexpr int qr = 1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_0; + static constexpr int qr = QR4_0; + static constexpr int qi = QI4_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_1; + static constexpr int qr = QR4_1; + static constexpr int qi = QI4_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_0; + static constexpr int qr = QR5_0; + static constexpr int qi = QI5_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK5_1; + static constexpr int qr = QR5_1; + static constexpr int qi = QI5_1; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK8_0; + static constexpr int qr = QR8_0; + static constexpr int qi = QI8_0; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_K; + static constexpr int qi = QI3_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_K; + static constexpr int qi = QI4_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_K; + static constexpr int qi = QI5_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_K; + static constexpr int qi = QI6_K; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XXS; + static constexpr int qi = QI2_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XS; + static constexpr int qi = QI2_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_S; + static constexpr int qi = QI2_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_XXS; + static constexpr int qi = QI3_XXS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_S; + static constexpr int qi = QI1_S; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_M; + static constexpr int qi = QI1_M; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK4_NL; + static constexpr int qr = QR4_NL; + static constexpr int qi = QI4_NL; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_S; + static constexpr int qi = QI3_S; +}; + +static int get_mmq_x_max_host(const int cc) { +#ifdef CUDA_USE_TENSOR_CORES + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64; +#else + return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64; +#endif // CUDA_USE_TENSOR_CORES +} + +// Round rows to this value for --split-mode row: +static int get_mmq_y_host(const int cc, const int mmq_x) { + return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64; +} + ////////////////////// struct ggml_cuda_device_info { diff --git a/llama/ggml-cuda/concat.cu b/llama/ggml-cuda/concat.cu index 2941d2f1..dac10ec3 100644 --- a/llama/ggml-cuda/concat.cu +++ b/llama/ggml-cuda/concat.cu @@ -1,15 +1,69 @@ #include "concat.cuh" -static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) { +// contiguous kernels +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { int nidx = threadIdx.x + blockIdx.x * blockDim.x; if (nidx >= ne0) { return; } - // operation + int offset_dst = nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y; + + if (nidx < ne00) { // src0 + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * ne00 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + (nidx - ne00) + + blockIdx.y * (ne0 - ne00) + + blockIdx.z * (ne0 - ne00) * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (blockIdx.y < ne01) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + (blockIdx.y - ne01) * ne0 + + blockIdx.z * ne0 * (gridDim.y - ne01); + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + if (blockIdx.z < ne02) { // src0 int offset_src = nidx + @@ -25,25 +79,118 @@ static __global__ void concat_f32(const float * x,const float * y, float * dst, } } -static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) { +static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2); - concat_f32<<>>(x, y, dst, ne0, ne02); + if (dim == 0) { + concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + return; + } + if (dim == 1) { + concat_f32_dim1<<>>(x, y, dst, ne0, ne01); + return; + } + concat_f32_dim2<<>>(x, y, dst, ne0, ne02); } +// non-contiguous kernel (slow) +static __global__ void concat_f32_non_cont( + const char * src0, + const char * src1, + char * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne03, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, + int64_t /*ne10*/, + int64_t /*ne11*/, + int64_t /*ne12*/, + int64_t /*ne13*/, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, + int64_t ne0, + int64_t /*ne1*/, + int64_t /*ne2*/, + int64_t /*ne3*/, + uint64_t nb0, + uint64_t nb1, + uint64_t nb2, + uint64_t nb3, + int32_t dim) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + const float * x; + + for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} + + void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + const int32_t dim = ((int32_t *) dst->op_params)[0]; + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream); + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + } + } else { + dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); + concat_f32_non_cont<<>>( + (const char *)src0->data, + (const char *)src1->data, + ( char *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } } diff --git a/llama/ggml-cuda/convert.cu b/llama/ggml-cuda/convert.cu index 830e2d75..c0a44470 100644 --- a/llama/ggml-cuda/convert.cu +++ b/llama/ggml-cuda/convert.cu @@ -131,7 +131,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t const block_q2_K * x = (const block_q2_K *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t n = tid/32; const int64_t l = tid - 32*n; const int64_t is = 8*n + l/16; @@ -145,17 +144,6 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); -#else - const int64_t is = tid/16; // 0 or 1 - const int64_t il = tid%16; // 0...15 - const uint8_t q = x[i].qs[il] >> (2*is); - dst_t * y = yy + i*QK_K + 16*is + il; - float dall = __low2half(x[i].dm); - float dmin = __high2half(x[i].dm); - y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); - y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); -#endif - } template @@ -164,7 +152,6 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; const block_q3_K * x = (const block_q3_K *) vx; -#if QK_K == 256 const int64_t r = threadIdx.x/4; const int64_t tid = r/2; const int64_t is0 = r%2; @@ -188,31 +175,8 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t const uint8_t * hm = x[i].hmask; for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); -#else - const int64_t tid = threadIdx.x; - const int64_t is = tid/16; // 0 or 1 - const int64_t il = tid%16; // 0...15 - const int64_t im = il/8; // 0...1 - const int64_t in = il%8; // 0...7 - - dst_t * y = yy + i*QK_K + 16*is + il; - - const uint8_t q = x[i].qs[il] >> (2*is); - const uint8_t h = x[i].hmask[in] >> (2*is + im); - const float d = (float)x[i].d; - - if (is == 0) { - y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); - y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); - } else { - y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4)); - y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4)); - } -#endif - } -#if QK_K == 256 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { d = q[j] & 63; m = q[j + 4] & 63; @@ -221,7 +185,6 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } -#endif template static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -229,7 +192,6 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; -#if QK_K == 256 // assume 32 threads const int64_t tid = threadIdx.x; const int64_t il = tid/8; @@ -253,15 +215,6 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t y[l + 0] = d1 * (q[l] & 0xF) - m1; y[l +32] = d2 * (q[l] >> 4) - m2; } -#else - const int64_t tid = threadIdx.x; - const uint8_t * q = x[i].qs; - dst_t * y = yy + i*QK_K; - const float d = (float)x[i].dm[0]; - const float m = (float)x[i].dm[1]; - y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4); - y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4); -#endif } template @@ -270,7 +223,6 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t const int64_t i = blockIdx.x; -#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int64_t tid = threadIdx.x; const int64_t il = tid/16; // il is in 0...3 @@ -297,18 +249,6 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t hm <<= 1; y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; -#else - const int64_t tid = threadIdx.x; - const uint8_t q = x[i].qs[tid]; - const int64_t im = tid/8; // 0...3 - const int64_t in = tid%8; // 0...7 - const int64_t is = tid/16; // 0 or 1 - const uint8_t h = x[i].qh[in] >> im; - const float d = x[i].d; - dst_t * y = yy + i*QK_K + tid; - y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16)); - y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16)); -#endif } template @@ -316,7 +256,6 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t const block_q6_K * x = (const block_q6_K *) vx; const int64_t i = blockIdx.x; -#if QK_K == 256 // assume 64 threads - this is very slightly better than the one below const int64_t tid = threadIdx.x; @@ -336,24 +275,6 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); -#else - - // assume 32 threads - const int64_t tid = threadIdx.x; - const int64_t ip = tid/16; // 0 or 1 - const int64_t il = tid - 16*ip; // 0...15 - - dst_t * y = yy + i*QK_K + 16*ip + il; - - const float d = x[i].d; - - const uint8_t ql = x[i].ql[16*ip + il]; - const uint8_t qh = x[i].qh[il] >> (2*ip); - const int8_t * sc = x[i].scales; - - y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32); - y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32); -#endif } template @@ -363,7 +284,6 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds const block_iq2_xxs * x = (const block_iq2_xxs *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -374,10 +294,6 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); -#else - NO_DEVICE_CODE; -#endif - } template @@ -387,7 +303,6 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst const block_iq2_xs * x = (const block_iq2_xs *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -396,10 +311,6 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); -#else - NO_DEVICE_CODE; -#endif - } template @@ -409,7 +320,6 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_ const block_iq2_s * x = (const block_iq2_s *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -417,10 +327,6 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; const uint8_t signs = x[i].qs[QK_K/8+4*ib+il]; for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); -#else - NO_DEVICE_CODE; -#endif - } template @@ -430,7 +336,6 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds const block_iq3_xxs * x = (const block_iq3_xxs *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -445,10 +350,6 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } -#else - NO_DEVICE_CODE; -#endif - } template @@ -458,7 +359,6 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ const block_iq3_s * x = (const block_iq3_s *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -471,10 +371,6 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } -#else - NO_DEVICE_CODE; -#endif - } template @@ -484,7 +380,6 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ const block_iq1_s * x = (const block_iq1_s *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -497,10 +392,6 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ for (int j = 0; j < 8; ++j) { y[j] = d * (q[j] + delta); } -#else - NO_DEVICE_CODE; -#endif - } template @@ -510,7 +401,6 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ const block_iq1_m * x = (const block_iq1_m *) vx; const int64_t tid = threadIdx.x; -#if QK_K == 256 const int64_t il = tid/8; // 0...3 const int64_t ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; @@ -527,13 +417,8 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ for (int j = 0; j < 8; ++j) { y[j] = d * (q[j] + delta); } -#else - NO_DEVICE_CODE; -#endif - } - template static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { @@ -550,10 +435,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; } - } -#if QK_K != 64 template static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int64_t i = blockIdx.x; @@ -570,7 +453,6 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; } } -#endif template static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { @@ -592,21 +474,13 @@ static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * template static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); -#else - dequantize_block_q2_K<<>>(vx, y); -#endif } template static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); -#else - dequantize_block_q3_K<<>>(vx, y); -#endif } template @@ -632,21 +506,13 @@ static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k template static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); -#else - dequantize_block_q5_K<<>>(vx, y); -#endif } template static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = k / QK_K; -#if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); -#else - dequantize_block_q6_K<<>>(vx, y); -#endif } template @@ -700,11 +566,7 @@ static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t template static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { const int nb = (k + QK_K - 1) / QK_K; -#if QK_K == 64 - dequantize_block_iq4_nl<<>>(vx, y); -#else dequantize_block_iq4_xs<<>>(vx, y); -#endif } template diff --git a/llama/ggml-cuda/dmmv.cu b/llama/ggml-cuda/dmmv.cu index 7313e3e1..174489e0 100644 --- a/llama/ggml-cuda/dmmv.cu +++ b/llama/ggml-cuda/dmmv.cu @@ -22,7 +22,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, float tmp = 0; // partial sum for thread in warp -#if QK_K == 256 const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 @@ -71,37 +70,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, tmp += dall * sum1 - dmin * sum2; } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; - - uint32_t uaux[2]; - const uint8_t * d = (const uint8_t *)uaux; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + offset; - const uint8_t * q = x[i].qs + offset; - const uint32_t * s = (const uint32_t *)x[i].scales; - - uaux[0] = s[0] & 0x0f0f0f0f; - uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; - - const float2 dall = __half22float2(x[i].dm); - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - const uint8_t ql = q[l]; - sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) - + y[l+16] * d[1] * ((ql >> 2) & 3) - + y[l+32] * d[2] * ((ql >> 4) & 3) - + y[l+48] * d[3] * ((ql >> 6) & 3); - sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; - } - tmp += dall.x * sum1 - dall.y * sum2; - } -#endif // sum up partial sums and write back result tmp = warp_reduce_sum(tmp); @@ -123,8 +91,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, float tmp = 0; // partial sum for thread in warp -#if QK_K == 256 - const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; @@ -175,34 +141,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, tmp += d * sum; } -#else - - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 - const int in = offset/8; // 0 or 1 - const int im = offset%8; // 0...7 - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + offset; - const uint8_t * q = x[i].qs + offset; - const uint8_t * s = x[i].scales; - - const float dall = (float)x[i].d; - - float sum = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - const uint8_t hl = x[i].hmask[im+l] >> in; - const uint8_t ql = q[l]; - sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) - + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4)) - + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4)) - + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4)); - } - tmp += sum; - } -#endif // sum up partial sums and write back result tmp = warp_reduce_sum(tmp); @@ -221,7 +159,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const block_q4_K * x = (const block_q4_K *)vx + ib0; -#if QK_K == 256 const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; @@ -306,36 +243,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, #endif } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); - - const int step = tid * K_QUANTS_PER_ITERATION; - - uint16_t aux16[2]; - const uint8_t * s = (const uint8_t *)aux16; - - float tmp = 0; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - const uint8_t * q = x[i].qs + step; - const float * y = yy + i*QK_K + step; - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - const float d = (float)x[i].dm[0]; - const float m = (float)x[i].dm[1]; - float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) - + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) - + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) - + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); - } - tmp += sum; - } - -#endif // sum up partial sums and write back result tmp = warp_reduce_sum(tmp); @@ -355,7 +262,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, float tmp = 0; // partial sum for thread in warp -#if QK_K == 256 const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; @@ -426,30 +332,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; } -#else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); - const int step = tid * K_QUANTS_PER_ITERATION; - const int im = step/8; - const int in = step%8; - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - const uint8_t * q = x[i].qs + step; - const int8_t * s = x[i].scales; - const float * y = yy + i*QK_K + step; - const float d = x[i].d; - float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - const uint8_t h = x[i].qh[in+j] >> im; - sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) - + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) - + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16)) - + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16)); - } - tmp += sum; - } -#endif - // sum up partial sums and write back result tmp = warp_reduce_sum(tmp); @@ -470,8 +352,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const block_q6_K * x = (const block_q6_K *)vx + ib0; -#if QK_K == 256 - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 @@ -526,37 +406,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, } -#else - - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 - - const int step = tid * K_QUANTS_PER_ITERATION; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + step; - const uint8_t * ql = x[i].ql + step; - const uint8_t * qh = x[i].qh + step; - const int8_t * s = x[i].scales; - - const float d = x[i+0].d; - - float sum = 0; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { - sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) - + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) - + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) - + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); - } - tmp += sum; - - } - -#endif - // sum up partial sums and write back result tmp = warp_reduce_sum(tmp); @@ -573,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int v.y = x[ib + iqs + 1]; } -template +static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 : + type == GGML_TYPE_Q4_1 ? dequantize_q4_1 : + type == GGML_TYPE_Q5_0 ? dequantize_q5_0 : + type == GGML_TYPE_Q5_1 ? dequantize_q5_1 : + type == GGML_TYPE_Q8_0 ? dequantize_q8_0 : + type == GGML_TYPE_F16 ? convert_f16 : + nullptr; +} + +template static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - // qk = quantized weights per x block - // qr = number of quantized weights per data value in x block + constexpr int qk = ggml_cuda_type_traits::qk; // quantized weights per x block + constexpr int qr = ggml_cuda_type_traits::qr; // number of quantized weights per data value in x block + constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type); + const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; if (row >= nrows) { @@ -644,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -653,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -662,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -671,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -680,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -731,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec<1, 1, convert_f16> + dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } diff --git a/llama/ggml-cuda/fattn-common.cuh b/llama/ggml-cuda/fattn-common.cuh index 1dd519bd..c00f8606 100644 --- a/llama/ggml-cuda/fattn-common.cuh +++ b/llama/ggml-cuda/fattn-common.cuh @@ -1,4 +1,8 @@ +#pragma once + #include "common.cuh" +#include "convert.cuh" +#include "vecdotq.cuh" #include @@ -34,11 +38,523 @@ typedef void (* fattn_kernel_t)( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, const int ne3); +typedef half (*vec_dot_KQ_f16_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +typedef float (*vec_dot_KQ_f32_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + half sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid4d8 + m4s8scaled); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid5d8 + m5s8scaled); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ >= MIN_CC_DP4A + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + const int v = get_int_from_int8(K_q8_0[ib].qs, iqs); + + T Q_d; + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + } else { + const float2 * Q_ds = (const float2 *) Q_ds_v; + Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + } + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const half2 * K_h2 = (const half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_h2 = (const half2 *) Q_v; + + half2 sum2 = make_half2(0.0f, 0.0f); + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + return __low2half(sum2) + __high2half(sum2); + } +#endif // FP16_AVAILABLE + + const float2 * Q_f2 = (const float2 *) Q_v; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + } + + return sum; +} + +template +static __device__ __forceinline__ void quantize_q8_1_to_shared( + const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { + + float vals[sizeof(int)] = {0.0f}; +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + vals[l] = scale * x[4*threadIdx.x + l]; + } + + float amax = fabsf(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < sizeof(int); ++l) { + amax = fmaxf(amax, fabsf(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + q8[l] = roundf(vals[l] / d); + } + } + + yq32[threadIdx.x] = q32; + if (threadIdx.x % QI8_1 == 0) { + if (std::is_same::value) { + ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); + } else { + ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum); + } + } +} + +typedef half (*dequantize_1_f16_t)(const void *, const int64_t); +typedef float (*dequantize_1_f32_t)(const void *, const int64_t); + +template +static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i / QK4_0; + const int iqs = i % (QK4_0/2); + const int shift = (i % QK4_0) / (QK4_0/2); + + const T d = x[ib].d; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F) - 8; + +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int64_t ib = i / QK4_1; + const int iqs = i % (QK4_1/2); + const int shift = (i % QK4_1) / (QK4_1/2); + + const half2 dm = x[ib].dm; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F); + +#if FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i / QK5_0; + const int idq = i % QK5_0; + const int iqs = i % (QK5_0/2); + const int shift = (i % QK5_0) / (QK5_0/2); + + const T d = x[ib].d; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_from_uint8(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh) - 16; + +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i / QK5_1; + const int idq = i % QK5_1; + const int iqs = i % (QK5_1/2); + const int shift = (i % QK5_1) / (QK5_1/2); + + const half2 dm = x[ib].dm; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh); + +#if FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i / QK8_0; + const int iqs = i % QK8_0; + + const T d = x[ib].d; + const int q = x[ib].qs[iqs]; + +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); +} + +template +static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { + const half * x = (const half *) vx; + + return x[i]; +} + +template +constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) { + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; +} + +template +constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) { + return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 : + type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 : + type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 : + type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 : + type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 : + type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 : + nullptr; +} + +constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) { + return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : + type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : + type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : + type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : + type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : + type_V == GGML_TYPE_F16 ? dequantize_1_f16 : + nullptr; +} + +constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { + return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 : + type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 : + type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 : + type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 : + type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 : + type_V == GGML_TYPE_F16 ? dequantize_1_f16 : + nullptr; +} + template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -83,8 +599,32 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } +static void on_no_fattn_vec_case(const int D) { + if (D == 64) { + fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); + fprintf(stderr, "By default only f16 KV cache is supported.\n"); + fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); + GGML_ASSERT(false); + } else if (D == 128) { + fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); + fprintf(stderr, "Supported combinations:\n"); + fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); + fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); + fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); + fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); + GGML_ASSERT(false); + } else { + fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Only f16 is supported.\n"); + GGML_ASSERT(false); + } +} + template -void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) { +void launch_fattn( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, + const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V +) { const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -94,8 +634,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_tensor * KQV = dst; GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(K->type == GGML_TYPE_F16); - GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -107,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); + ggml_cuda_pool_alloc K_f16(pool); + ggml_cuda_pool_alloc V_f16(pool); ggml_cuda_pool_alloc dst_tmp(pool); ggml_cuda_pool_alloc dst_tmp_meta(pool); + char * K_data = (char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + char * V_data = (char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + K_f16.alloc(ggml_nelements(K)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + K_data = (char *) K_f16.ptr; + + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + V_f16.alloc(ggml_nelements(V)); + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } + if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); @@ -133,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern fattn_kernel<<>>( (const char *) Q->data, - (const char *) K->data, - (const char *) V->data, + K_data, + V_data, mask ? ((const char *) mask->data) : nullptr, (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, @@ -142,7 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern K->ne[0], K->ne[1], K->ne[2], K->ne[3], mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], - K->nb[1], K->nb[2], K->nb[3], + nb11, nb12, nb13, + nb21, nb22, nb23, KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); CUDA_CHECK(cudaGetLastError()); diff --git a/llama/ggml-cuda/fattn-tile-f16.cu b/llama/ggml-cuda/fattn-tile-f16.cu index 4a07ac6a..cb11d721 100644 --- a/llama/ggml-cuda/fattn-tile-f16.cu +++ b/llama/ggml-cuda/fattn-tile-f16.cu @@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -83,7 +86,7 @@ static __global__ void flash_attn_tile_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i]; + const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); } } @@ -238,6 +241,10 @@ static __global__ void flash_attn_tile_ext_f16( for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); kqsum_j = warp_reduce_sum(kqsum_j); @@ -271,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); diff --git a/llama/ggml-cuda/fattn-tile-f32.cu b/llama/ggml-cuda/fattn-tile-f32.cu index 130e7cbd..15e22f49 100644 --- a/llama/ggml-cuda/fattn-tile-f32.cu +++ b/llama/ggml-cuda/fattn-tile-f32.cu @@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -79,7 +82,7 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x]; + float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; } @@ -237,6 +240,10 @@ static __global__ void flash_attn_tile_ext_f32( for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { const int j_VKQ = j_VKQ_0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + float kqsum_j = kqsum[j_VKQ_0/nwarps]; kqsum_j = warp_reduce_sum(kqsum_j); @@ -268,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int D = 64; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; case 128: { constexpr int D = 128; constexpr int nwarps = 8; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); } break; default: { GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); @@ -283,11 +290,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * } void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const int32_t precision = KQV->op_params[2]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); + const ggml_tensor * Q = dst->src[0]; if (Q->ne[1] <= 16) { constexpr int cols_per_block = 16; diff --git a/llama/ggml-cuda/fattn-vec-f16.cuh b/llama/ggml-cuda/fattn-vec-f16.cuh index c7023610..9e1aa2c6 100644 --- a/llama/ggml-cuda/fattn-vec-f16.cuh +++ b/llama/ggml-cuda/fattn-vec-f16.cuh @@ -1,5 +1,397 @@ #include "common.cuh" +#include "fattn-common.cuh" -void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. -void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16(type_K); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V); + + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); + + const half * maskh = (const half *) mask + ne11*ic0; + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[ncols*D]; + half2 * KQ2 = (half2 *) KQ; + + half kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -HALF_MAX_HALF; + } + half kqsum[ncols] = {0.0f}; + + __shared__ half kqmax_shared[ncols][WARP_SIZE]; + __shared__ half kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + __syncthreads(); + + // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } + } + + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -HALF_MAX_HALF; + } + + half2 VKQ[ncols] = {{0.0f, 0.0f}}; + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, + // see https://github.com/ggerganov/llama.cpp/pull/7061 . + // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). + half kqmax_new = kqmax[0]; + half kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + + if (ncols == 1) { + kqmax_new = ggml_cuda_hmax(kqmax_new, sum); + } else { + kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum); + } + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + half kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const half val = hexp(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; + + VKQ[j] *= __half2half2(KQ_max_scale); + } + + __syncthreads(); + +#pragma unroll + for (int k0 = 0; k0 < D; k0 += 2) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) { + break; + } + + half2 V_k; + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum(kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 2 && ic0 + j_VKQ >= ne01) { + break; + } + + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + + half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ])); + if (parallel_blocks == 1) { + dst_val /= kqsum[j_VKQ]; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + } + + if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); + } +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE +} + +template +void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + constexpr int nwarps = D/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + constexpr bool need_f16_K = D != 128; + constexpr bool need_f16_V = D != 128 && D != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); +} + +template +void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * KQV = dst; + ggml_tensor * Q = dst->src[0]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + + const int32_t precision = KQV->op_params[2]; + GGML_ASSERT(precision == GGML_PREC_DEFAULT); + + GGML_ASSERT(K->type == type_K); + GGML_ASSERT(V->type == type_V); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + return; + } + + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + return; + } + + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + return; + } + + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); + return; + } + + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 1; + ggml_cuda_flash_attn_ext_vec_f16_case_impl(ctx, dst); +} + +#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); + +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); + +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); + +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); + +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); + +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/llama/ggml-cuda/fattn-vec-f32.cuh b/llama/ggml-cuda/fattn-vec-f32.cuh index 614d54ae..ddf0c837 100644 --- a/llama/ggml-cuda/fattn-vec-f32.cuh +++ b/llama/ggml-cuda/fattn-vec-f32.cuh @@ -1,3 +1,374 @@ #include "common.cuh" +#include "fattn-common.cuh" -void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_vec_ext_f32( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32(type_K); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V); + + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic0; + + const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ float KQ[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -FLT_MAX/2.0f; + } + + float kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -FLT_MAX/2.0f; + } + float kqsum[ncols] = {0.0f}; + + __shared__ float kqmax_shared[ncols][WARP_SIZE]; + __shared__ float kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -FLT_MAX/2.0f; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + __syncthreads(); + + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: + float2 Q_f2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)]; + float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_f2[j][i0/WARP_SIZE].x *= scale; + Q_f2[j][i0/WARP_SIZE].y *= scale; + } + } + } + + float VKQ[ncols] = {0.0f}; + + const int k_start = parallel_blocks == 1 ? 0 : ip*D; + for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) { + // Calculate KQ tile and keep track of new maximum KQ values: + + float kqmax_new_arr[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax_new_arr[j] = kqmax[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + + if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) { + break; + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); + + if (threadIdx.x == 0) { + KQ[j*D + i_KQ] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_new_arr[j]; + + kqmax_new_j = warp_reduce_max(kqmax_new_j); + if (threadIdx.x == 0) { + kqmax_shared[j][threadIdx.y] = kqmax_new_j; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float kqmax_new_j = kqmax_shared[j][threadIdx.x]; + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j); + kqmax[j] = kqmax_new_j; + + const float val = expf(KQ[j*D + tid] - kqmax[j]); + kqsum[j] = kqsum[j]*KQ_max_scale + val; + KQ[j*D + tid] = val; + + VKQ[j] *= KQ_max_scale; + } + + __syncthreads(); + +#pragma unroll + for (int k = 0; k < D; ++k) { + if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k >= ne11) { + break; + } + + const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j] += V_ki*KQ[j*D + k]; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqsum[j] = warp_reduce_sum(kqsum[j]); + if (threadIdx.x == 0) { + kqsum_shared[j][threadIdx.y] = kqsum[j]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 2 && ic0 + j_VKQ >= ne01) { + break; + } + + kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x]; + kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]); + + float dst_val = VKQ[j_VKQ]; + if (parallel_blocks == 1) { + dst_val /= kqsum[j_VKQ]; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val; + } + + if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { + dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); + } +} + +template +void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + constexpr int nwarps = D/WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; + constexpr bool need_f16_K = D != 128; + constexpr bool need_f16_V = D != 128 && D != 64; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); +} + +template +void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[0]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K->type == type_K); + GGML_ASSERT(V->type == type_V); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + return; + } + + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + return; + } + + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + return; + } + + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); + return; + } + + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 1; + ggml_cuda_flash_attn_ext_vec_f32_case_impl(ctx, dst); +} + +#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_f32_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); + +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); + +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); + +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); + +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); + +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); +extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); + +extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/llama/ggml-cuda/fattn-wmma-f16.cuh b/llama/ggml-cuda/fattn-wmma-f16.cuh new file mode 100644 index 00000000..59cd30d7 --- /dev/null +++ b/llama/ggml-cuda/fattn-wmma-f16.cuh @@ -0,0 +1,490 @@ +#include "common.cuh" +#include "fattn-common.cuh" + +#if FP16_MMA_AVAILABLE +#include +#endif + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_MMA_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. + const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef nvcuda::wmma::fragment frag_a_K; + typedef nvcuda::wmma::fragment frag_a_V; + typedef nvcuda::wmma::fragment frag_b; + typedef nvcuda::wmma::fragment frag_c_KQ; + typedef nvcuda::wmma::fragment frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); + const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); + const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; + const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + const half2 slope2 = make_half2(slopef, slopef); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); + } + KQ_max_new = warp_reduce_max(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/WARP_SIZE] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/WARP_SIZE] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); + } + KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; + } + KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + nvcuda::wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + nvcuda::wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, nvcuda::wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= ne01) { + return; + } + const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; + + float KQ_rowsum_j; + if (std::is_same::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + if (i0 + WARP_SIZE > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (parallel_blocks == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; + } + + if (parallel_blocks == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + } +#else + NO_DEVICE_CODE; +#endif // FP16_MMA_AVAILABLE +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template +void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + constexpr int nwarps = 4; + + constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + if (4*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 4; + fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + return; + } + if (2*blocks_num_pb1 < 2*nsm) { + constexpr int parallel_blocks = 2; + fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); + return; + } + constexpr int parallel_blocks = 1; + fattn_kernel_t fattn_kernel = flash_attn_ext_f16; + launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); +} + +#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \ + template void ggml_cuda_flash_attn_ext_wmma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); +extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); +extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float); +extern DECL_FATTN_WMMA_F16_CASE(112, 16, float); +extern DECL_FATTN_WMMA_F16_CASE(128, 16, float); +extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); + +extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float); +extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float); +extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float); +extern DECL_FATTN_WMMA_F16_CASE(112, 32, float); +extern DECL_FATTN_WMMA_F16_CASE(128, 32, float); +// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); + +extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half); +extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half); +extern DECL_FATTN_WMMA_F16_CASE(128, 8, half); +extern DECL_FATTN_WMMA_F16_CASE(256, 8, half); + +extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half); +extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half); +extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half); +extern DECL_FATTN_WMMA_F16_CASE(112, 16, half); +extern DECL_FATTN_WMMA_F16_CASE(128, 16, half); +extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); + +extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half); +extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half); +extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); +extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); +extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); +extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); diff --git a/llama/ggml-cuda/fattn.cu b/llama/ggml-cuda/fattn.cu index af7c9523..38d30b21 100644 --- a/llama/ggml-cuda/fattn.cu +++ b/llama/ggml-cuda/fattn.cu @@ -4,454 +4,295 @@ #include "fattn-tile-f32.cuh" #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" +#include "fattn-wmma-f16.cuh" #include "fattn.cuh" #include -#if FP16_MMA_AVAILABLE -#include -#endif +static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; -// D == head size, VKQ_stride == num VKQ rows calculated in parallel: -template -#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -__launch_bounds__(nwarps*WARP_SIZE, 1) -#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) -static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int nb31, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { -#if FP16_MMA_AVAILABLE - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + const int32_t precision = KQV->op_params[2]; - const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. - const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. - - static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); - static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); - constexpr int frag_m = ncols == 8 ? 32 : 16; - constexpr int frag_n = ncols == 8 ? 8 : 16; - static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); - typedef nvcuda::wmma::fragment frag_a_K; - typedef nvcuda::wmma::fragment frag_a_V; - typedef nvcuda::wmma::fragment frag_b; - typedef nvcuda::wmma::fragment frag_c_KQ; - typedef nvcuda::wmma::fragment frag_c_VKQ; - - constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. - constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. - static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); - - // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: - constexpr int D_padded = D + 8; - constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; - constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); - - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0); - const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0; - const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2); - - const int stride_Q = nb01 / sizeof(float); - const int stride_KV = nb11 / sizeof(half); - - const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - const half2 slope2 = make_half2(slopef, slopef); - - frag_b Q_b[D/16][ncols/frag_n]; - - // A single buffer for temporarily holding tiles of KQ and VKQ parts: - constexpr int mem_KQ = ncols*kqs_padded*kqar; - constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; - __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; - float * KQ_f = (float *) KQ; - half2 * KQ2 = (half2 *) KQ; - - float KQ_rowsum_f[ncols/nwarps] = {0.0f}; - float KQ_max_f[ncols/nwarps]; - float KQ_max_scale_f[ncols/nwarps] = {0.0f}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_f[j] = -FLT_MAX/2.0f; - } - - half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - half2 KQ_max_h2[ncols/nwarps]; - half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; - -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); - } - - __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. - half2 * VKQ2 = (half2 *) VKQ; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { - break; - } - VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); - } - } - - // Convert Q to half and apply scale, temporarily store in KQ: -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { - break; - } - KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; - } - } - - __syncthreads(); - - // Load Q into tensor core fragments/registers since it will be used frequently: -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 16) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); - } - } - - __syncthreads(); - - // Iterate over ne11 == previous tokens: - for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { - frag_c_KQ KQ_c[ncols/frag_n]; -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); - } -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { - frag_a_K K_a; - nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); - } - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - - // Calculate softmax for each KQ column using the current max. value. - // The divisor is stored in KQ_rowsum and will be applied at the end. -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (std::is_same::value) { - float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; - } - - float KQ_max_new = KQ_max_f[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; - KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); - } - KQ_max_new = warp_reduce_max(KQ_max_new); - - const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; - KQ_max_scale_f[j0/nwarps] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_max_scale_f[j0/nwarps] = 0.0f; - } - KQ_max_f[j0/nwarps] = KQ_max_new; - - float KQ_rowsum_add = 0.0f; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; - KQ_f_tmp[k0/WARP_SIZE] = expf(diff); - if (diff <= SOFTMAX_FTZ_THRESHOLD) { - KQ_f_tmp[k0/WARP_SIZE] = 0.0f; - } - KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; - KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; - } else { - half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; - } - - half2 KQ_max_new = KQ_max_h2[j0/nwarps]; -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); - KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); - } - KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); - const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; - KQ_max_scale_h2[j0/nwarps] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; - KQ_max_h2[j0/nwarps] = KQ_max_new; - - half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { - const int k = k0 + threadIdx.x; - - const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; - KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); - const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); - *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; - KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; - KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; - } - KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); - - // Scale previous KQ_rowsum to account for a potential increase in KQ_max: - KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; - } - } - - __syncthreads(); - - frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - nvcuda::wmma::load_matrix_sync( - KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, - kqar*kqs_padded); - } - } - - frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); - } - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { - const int k = k0 + (threadIdx.y % VKQ_ratio)*16; - - frag_a_V v_a; - nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); -#pragma unroll - for (int j = 0; j < ncols/frag_n; ++j) { - nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); - } - } - } - - __syncthreads(); - - const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += frag_n) { - nvcuda::wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), - VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], - D_padded, nvcuda::wmma::mem_col_major); - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - half2 VKQ_scale; - if (std::is_same::value) { - VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); - } else { - VKQ_scale = KQ_max_scale_h2[j0/nwarps]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D/2 && i >= D/2) { + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + break; + default: + GGML_ASSERT(false); + break; + } + } else { + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + // case 256: + // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + // break; + default: + GGML_ASSERT(false); break; - } - - half2 VKQ_add = make_half2(0.0f, 0.0f); -#pragma unroll - for (int l = 0; l < VKQ_ratio; ++l) { - VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; - } - VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; } } - - __syncthreads(); + return; } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j_VKQ = j0 + threadIdx.y; - if (ic0 + j_VKQ >= ne01) { - return; - } - const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; - - float KQ_rowsum_j; - if (std::is_same::value) { - KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; - } else { - KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); - } - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE > D && i >= D) { + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ASSERT(false); break; - } - float dst_val = VKQ[j_VKQ*D_padded + i]; - if (parallel_blocks == 1) { - dst_val /= KQ_rowsum_j; - } - dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; } - - if (parallel_blocks == 1 || threadIdx.x != 0) { - continue; - } - - float2 dst_meta_val; - if (std::is_same::value) { - dst_meta_val.x = KQ_max_f[j0/nwarps]; - } else { - dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); - } - dst_meta_val.y = KQ_rowsum_j; - dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; + return; } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ASSERT(false); + break; + } + return; + } + + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ASSERT(false); + break; + } +} +#define FATTN_VEC_F16_CASE(D, type_K, type_V) \ + if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \ + return; \ + } \ + +static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[1]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0) + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 ) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) + + FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) #else - NO_DEVICE_CODE; -#endif // FP16_MMA_AVAILABLE + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + + FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + + FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); } -constexpr int get_max_power_of_2(int x) { - return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; -} +#define FATTN_VEC_F32_CASE(D, type_K, type_V) \ + if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ + ggml_cuda_flash_attn_ext_vec_f32_case(ctx, dst); \ + return; \ + } \ -static_assert(get_max_power_of_2(1) == 1, "Test failed."); -static_assert(get_max_power_of_2(2) == 2, "Test failed."); -static_assert(get_max_power_of_2(4) == 4, "Test failed."); -static_assert(get_max_power_of_2(6) == 2, "Test failed."); +static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[1]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; -// Number of VKQ rows calculated in parallel: -constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { - return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; -} +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) -static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); -static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); -static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); -static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); -static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0) -template -void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1) - constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; - const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; - const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0) - if (4*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 4; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - return; - } - if (2*blocks_num_pb1 < 2*nsm) { - constexpr int parallel_blocks = 2; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - return; - } - constexpr int parallel_blocks = 1; - fattn_kernel_t fattn_kernel = flash_attn_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1) + + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0) + + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) + + FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) +#else + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -464,8 +305,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst // On AMD the tile kernels perform poorly, use the vec kernel instead: if (cc >= CC_OFFSET_AMD) { - if (precision == GGML_PREC_DEFAULT) { - ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst); + if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); } @@ -483,156 +324,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst if (!fp16_mma_available(cc)) { if (Q->ne[1] <= 8) { - ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst); + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); } else { ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); } return; } - if (precision != GGML_PREC_DEFAULT) { - if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) { + if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { + if (precision == GGML_PREC_DEFAULT) { + ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); + return; + } else if(Q->ne[0] <= 128) { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); return; } - - if (Q->ne[1] <= 32 || Q->ne[0] > 128) { - constexpr int cols_per_block = 16; - constexpr int nwarps = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst); - break; - case 80: - launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst); - break; - case 96: - launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst); - break; - case 112: - launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst); - break; - case 128: - launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst); - break; - case 256: - launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst); - break; - default: - GGML_ASSERT(false); - break; - } - } else { - constexpr int cols_per_block = 32; - constexpr int nwarps = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst); - break; - case 80: - launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst); - break; - case 96: - launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst); - break; - case 112: - launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst); - break; - case 128: - launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst); - break; - // case 256: - // launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst); - // break; - default: - GGML_ASSERT(false); - break; - } - } - return; } - if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } - - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { - constexpr int cols_per_block = 8; - constexpr int nwarps = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst); - break; - case 96: - launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst); - break; - case 128: - launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst); - break; - case 256: - launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst); - break; - default: - GGML_ASSERT(false); - break; - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 16; - constexpr int nwarps = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst); - break; - case 80: - launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst); - break; - case 96: - launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst); - break; - case 112: - launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst); - break; - case 128: - launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst); - break; - case 256: - launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst); - break; - default: - GGML_ASSERT(false); - break; - } - return; - } - - constexpr int cols_per_block = 32; - constexpr int nwarps = 4; - switch (Q->ne[0]) { - case 64: - launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst); - break; - case 80: - launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst); - break; - case 96: - launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst); - break; - case 112: - launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst); - break; - case 128: - launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst); - break; - case 256: - launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst); - break; - default: - GGML_ASSERT(false); - break; - } - return; + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); } diff --git a/llama/ggml-cuda/mmq.cu b/llama/ggml-cuda/mmq.cu index 7948f1b1..58799e4c 100644 --- a/llama/ggml-cuda/mmq.cu +++ b/llama/ggml-cuda/mmq.cu @@ -1,2178 +1,4 @@ #include "mmq.cuh" -#include "vecdotq.cuh" - -typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); -typedef void (*load_tiles_cuda_t)( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row); -typedef float (*vec_dot_q_mul_mat_cuda_t)( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k); -typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v); - -template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - GGML_UNUSED(x_sc); - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; - - *x_ql = tile_x_qs; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q4_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_0; - const int kqsx = k % QI4_0; - - const block_q4_0 * bx0 = (const block_q4_0 *) vx; - - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const float * x_dmf = (const float *) x_dm; - - int u[2*VDR_Q4_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; - } - - return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; - - *x_ql = tile_x_qs; - *x_dm = tile_x_dm; -} - -template static __device__ __forceinline__ void load_tiles_q4_1( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_1; - const int kqsx = k % QI4_1; - - const block_q4_1 * bx0 = (const block_q4_1 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; - } -} - -static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - - int u[2*VDR_Q4_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; - } - - return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; - - *x_ql = tile_x_ql; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q5_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_0; - const int kqsx = k % QI5_0; - - const block_q5_0 * bx0 = (const block_q5_0 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx; - - const int ql = get_int_from_uint8(bxi->qs, kqsx); - const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0)); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - int u[2*VDR_Q5_0_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; - } - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - - -template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; -} - -template static __device__ __forceinline__ void load_tiles_q5_1( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_1; - const int kqsx = k % QI5_1; - - const block_q5_1 * bx0 = (const block_q5_1 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx; - - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1)); - - int qs0 = (ql >> 0) & 0x0F0F0F0F; - qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 - qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 - qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 - qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; - - int qs1 = (ql >> 4) & 0x0F0F0F0F; - qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 - qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 - qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 - qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; - } -} - -static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; - - int u[2*VDR_Q5_1_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; - } - - return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); -} - -template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; - - *x_ql = tile_x_qs; - *x_dm = (half2 *) tile_x_d; -} - -template static __device__ __forceinline__ void load_tiles_q8_0( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI8_0; - const int kqsx = k % QI8_0; - float * x_dmf = (float *) x_dm; - - const block_q8_0 * bx0 = (const block_q8_0 *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { - int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row; - - if (need_check) { - i = min(i, i_max); - } - - const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; - } -} - -static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q2_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI2_K; - const int kqsx = k % QI2_K; - - const block_q2_K * bx0 = (const block_q2_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; - const int kbxd = k % blocks_per_tile_x_row; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { - int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); - - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); - } -} - -static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const int kbx = k / QI2_K; - const int ky = (k % QI2_K) * QR2_K; - const float * y_df = (const float *) y_ds; - - int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); - const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); - -#pragma unroll - for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { - v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; - } - - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; - - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; - __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_qh = tile_x_qh; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q3_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI3_K; - const int kqsx = k % QI3_K; - - const block_q3_K * bx0 = (const block_q3_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; - const int kbxd = k % blocks_per_tile_x_row; - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { - int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); - - // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); - - const int ksc = k % (QI3_K/4); - - const int ksc_low = ksc % (QI3_K/8); - const int shift_low = 4 * (ksc / (QI3_K/8)); - const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; - - const int ksc_high = QI3_K/8; - const int shift_high = 2 * ksc; - const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; - - const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; - } -} - -static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - - const int kbx = k / QI3_K; - const int ky = (k % QI3_K) * QR3_K; - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; - - int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; - -#pragma unroll - for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); - const int shift = 2 * ((ky % 32) / 8); - const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); - const int vlh = (vh << 2) & 0x04040404; - - v[l] = __vsubss4(vll, vlh); - } - - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q4_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI4_K; // == 0 if QK_K == 256 - const int kqsx = k % QI4_K; // == k if QK_K == 256 - - const block_q4_K * bx0 = (const block_q4_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; - - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; - -#if QK_K == 256 - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; -#else - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]}; -#endif - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); - - const int * scales = (const int *) bxi->scales; - - const int ksc = k % (WARP_SIZE/8); - - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } -} - -static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); - - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q5_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI5_K; // == 0 if QK_K == 256 - const int kqsx = k % QI5_K; // == k if QK_K == 256 - - const block_q5_K * bx0 = (const block_q5_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR5_K*kqsx; - - const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; - - const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); - - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; - -#if QK_K == 256 - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; -#endif - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); - - const int * scales = (const int *) bxi->scales; - - const int ksc = k % (WARP_SIZE/8); - - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; - } -} - -static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; - return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); -} - -template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - GGML_UNUSED(x_qh); - - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; - - *x_ql = tile_x_ql; - *x_dm = tile_x_dm; - *x_sc = tile_x_sc; -} - -template static __device__ __forceinline__ void load_tiles_q6_K( - const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, - int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) { - GGML_UNUSED(x_qh); - - GGML_CUDA_ASSUME(i_offset >= 0); - GGML_CUDA_ASSUME(i_offset < nwarps); - GGML_CUDA_ASSUME(k >= 0); - GGML_CUDA_ASSUME(k < WARP_SIZE); - - const int kbx = k / QI6_K; // == 0 if QK_K == 256 - const int kqsx = k % QI6_K; // == k if QK_K == 256 - - const block_q6_K * bx0 = (const block_q6_K *) vx; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + i_offset; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx; - const int ky = QR6_K*kqsx; - - const int ql = get_int_from_uint8(bxi->ql, kqsx); - const int ql0 = (ql >> 0) & 0x0F0F0F0F; - const int ql1 = (ql >> 4) & 0x0F0F0F0F; - - const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; - - const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); - - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); - } - - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 - float * x_dmf = (float *) x_dm; - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; - - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; - } - -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; - - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); - } -} - -static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( - const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, - const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - GGML_UNUSED(x_qh); - - const float * x_dmf = (const float *) x_dm; - const float * y_df = (const float *) y_ds; - - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); - - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); -} - -#define MMQ_X_Q4_0_RDNA2 64 -#define MMQ_Y_Q4_0_RDNA2 128 -#define NWARPS_Q4_0_RDNA2 8 -#define MMQ_X_Q4_0_RDNA1 64 -#define MMQ_Y_Q4_0_RDNA1 64 -#define NWARPS_Q4_0_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q4_0_AMPERE 4 -#define MMQ_Y_Q4_0_AMPERE 32 -#define NWARPS_Q4_0_AMPERE 4 -#else -#define MMQ_X_Q4_0_AMPERE 64 -#define MMQ_Y_Q4_0_AMPERE 128 -#define NWARPS_Q4_0_AMPERE 4 -#endif -#define MMQ_X_Q4_0_PASCAL 64 -#define MMQ_Y_Q4_0_PASCAL 64 -#define NWARPS_Q4_0_PASCAL 8 - -template -static __device__ __forceinline__ void mul_mat_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; - - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; - - const int & ncols_dst = ncols_y; - - const int row_dst_0 = blockIdx.x*mmq_y; - const int & row_x_0 = row_dst_0; - - const int col_dst_0 = blockIdx.y*mmq_x; - const int & col_y_0 = col_dst_0; - - int * tile_x_ql = nullptr; - half2 * tile_x_dm = nullptr; - int * tile_x_qh = nullptr; - int * tile_x_sc = nullptr; - - allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - - __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; - __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; - - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; - - for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { - - load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, - threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); - -#pragma unroll - for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir*WARP_SIZE + threadIdx.x; - const int kbxd = kqs / QI8_1; - -#pragma unroll - for (int i = 0; i < mmq_x; i += nwarps) { - const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses - - const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; - - const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; - tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); - } - -#pragma unroll - for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); - const int col_y_eff = min(col_y_0 + ids, ncols_y-1); - - // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; - if (need_sum) { - *dsi_dst = *dsi_src; - } else { - float * dfi_dst = (float *) dsi_dst; - *dfi_dst = __low2float(*dsi_src); - } - } - - __syncthreads(); - -// #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { -#pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { -#pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i/WARP_SIZE][j/nwarps] += vec_dot( - tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, - threadIdx.x + i, threadIdx.y + j, k); - } - } - } - - __syncthreads(); - } - } - -#pragma unroll - for (int j = 0; j < mmq_x; j += nwarps) { - const int col_dst = col_dst_0 + j + threadIdx.y; - - if (col_dst >= ncols_dst) { - return; - } - -#pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - const int row_dst = row_dst_0 + threadIdx.x + i; - - if (row_dst >= nrows_dst) { - continue; - } - - dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; - } - } -} - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q4_0_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q4_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q4_0_RDNA2; - const int mmq_y = MMQ_Y_Q4_0_RDNA2; - const int nwarps = NWARPS_Q4_0_RDNA2; -#else - const int mmq_x = MMQ_X_Q4_0_RDNA1; - const int mmq_y = MMQ_Y_Q4_0_RDNA1; - const int nwarps = NWARPS_Q4_0_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q4_0_AMPERE; - const int mmq_y = MMQ_Y_Q4_0_AMPERE; - const int nwarps = NWARPS_Q4_0_AMPERE; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q4_0_PASCAL; - const int mmq_y = MMQ_Y_Q4_0_PASCAL; - const int nwarps = NWARPS_Q4_0_PASCAL; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q4_1_RDNA2 64 -#define MMQ_Y_Q4_1_RDNA2 128 -#define NWARPS_Q4_1_RDNA2 8 -#define MMQ_X_Q4_1_RDNA1 64 -#define MMQ_Y_Q4_1_RDNA1 64 -#define NWARPS_Q4_1_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q4_1_AMPERE 4 -#define MMQ_Y_Q4_1_AMPERE 32 -#define NWARPS_Q4_1_AMPERE 4 -#else -#define MMQ_X_Q4_1_AMPERE 64 -#define MMQ_Y_Q4_1_AMPERE 128 -#define NWARPS_Q4_1_AMPERE 4 -#endif -#define MMQ_X_Q4_1_PASCAL 64 -#define MMQ_Y_Q4_1_PASCAL 64 -#define NWARPS_Q4_1_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q4_1_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q4_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q4_1_RDNA2; - const int mmq_y = MMQ_Y_Q4_1_RDNA2; - const int nwarps = NWARPS_Q4_1_RDNA2; -#else - const int mmq_x = MMQ_X_Q4_1_RDNA1; - const int mmq_y = MMQ_Y_Q4_1_RDNA1; - const int nwarps = NWARPS_Q4_1_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q4_1_AMPERE; - const int mmq_y = MMQ_Y_Q4_1_AMPERE; - const int nwarps = NWARPS_Q4_1_AMPERE; - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q4_1_PASCAL; - const int mmq_y = MMQ_Y_Q4_1_PASCAL; - const int nwarps = NWARPS_Q4_1_PASCAL; - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q5_0_RDNA2 64 -#define MMQ_Y_Q5_0_RDNA2 128 -#define NWARPS_Q5_0_RDNA2 8 -#define MMQ_X_Q5_0_RDNA1 64 -#define MMQ_Y_Q5_0_RDNA1 64 -#define NWARPS_Q5_0_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q5_0_AMPERE 4 -#define MMQ_Y_Q5_0_AMPERE 32 -#define NWARPS_Q5_0_AMPERE 4 -#else -#define MMQ_X_Q5_0_AMPERE 128 -#define MMQ_Y_Q5_0_AMPERE 64 -#define NWARPS_Q5_0_AMPERE 4 -#endif -#define MMQ_X_Q5_0_PASCAL 64 -#define MMQ_Y_Q5_0_PASCAL 64 -#define NWARPS_Q5_0_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q5_0_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q5_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q5_0_RDNA2; - const int mmq_y = MMQ_Y_Q5_0_RDNA2; - const int nwarps = NWARPS_Q5_0_RDNA2; -#else - const int mmq_x = MMQ_X_Q5_0_RDNA1; - const int mmq_y = MMQ_Y_Q5_0_RDNA1; - const int nwarps = NWARPS_Q5_0_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q5_0_AMPERE; - const int mmq_y = MMQ_Y_Q5_0_AMPERE; - const int nwarps = NWARPS_Q5_0_AMPERE; - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q5_0_PASCAL; - const int mmq_y = MMQ_Y_Q5_0_PASCAL; - const int nwarps = NWARPS_Q5_0_PASCAL; - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q5_1_RDNA2 64 -#define MMQ_Y_Q5_1_RDNA2 128 -#define NWARPS_Q5_1_RDNA2 8 -#define MMQ_X_Q5_1_RDNA1 64 -#define MMQ_Y_Q5_1_RDNA1 64 -#define NWARPS_Q5_1_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q5_1_AMPERE 4 -#define MMQ_Y_Q5_1_AMPERE 32 -#define NWARPS_Q5_1_AMPERE 4 -#else -#define MMQ_X_Q5_1_AMPERE 128 -#define MMQ_Y_Q5_1_AMPERE 64 -#define NWARPS_Q5_1_AMPERE 4 -#endif -#define MMQ_X_Q5_1_PASCAL 64 -#define MMQ_Y_Q5_1_PASCAL 64 -#define NWARPS_Q5_1_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q5_1_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q5_1( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q5_1_RDNA2; - const int mmq_y = MMQ_Y_Q5_1_RDNA2; - const int nwarps = NWARPS_Q5_1_RDNA2; -#else - const int mmq_x = MMQ_X_Q5_1_RDNA1; - const int mmq_y = MMQ_Y_Q5_1_RDNA1; - const int nwarps = NWARPS_Q5_1_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q5_1_AMPERE; - const int mmq_y = MMQ_Y_Q5_1_AMPERE; - const int nwarps = NWARPS_Q5_1_AMPERE; - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q5_1_PASCAL; - const int mmq_y = MMQ_Y_Q5_1_PASCAL; - const int nwarps = NWARPS_Q5_1_PASCAL; - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q8_0_RDNA2 64 -#define MMQ_Y_Q8_0_RDNA2 128 -#define NWARPS_Q8_0_RDNA2 8 -#define MMQ_X_Q8_0_RDNA1 64 -#define MMQ_Y_Q8_0_RDNA1 64 -#define NWARPS_Q8_0_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q8_0_AMPERE 4 -#define MMQ_Y_Q8_0_AMPERE 32 -#define NWARPS_Q8_0_AMPERE 4 -#else -#define MMQ_X_Q8_0_AMPERE 128 -#define MMQ_Y_Q8_0_AMPERE 64 -#define NWARPS_Q8_0_AMPERE 4 -#endif -#define MMQ_X_Q8_0_PASCAL 64 -#define MMQ_Y_Q8_0_PASCAL 64 -#define NWARPS_Q8_0_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q8_0_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - mul_mat_q8_0( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q8_0_RDNA2; - const int mmq_y = MMQ_Y_Q8_0_RDNA2; - const int nwarps = NWARPS_Q8_0_RDNA2; -#else - const int mmq_x = MMQ_X_Q8_0_RDNA1; - const int mmq_y = MMQ_Y_Q8_0_RDNA1; - const int nwarps = NWARPS_Q8_0_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q8_0_AMPERE; - const int mmq_y = MMQ_Y_Q8_0_AMPERE; - const int nwarps = NWARPS_Q8_0_AMPERE; - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q8_0_PASCAL; - const int mmq_y = MMQ_Y_Q8_0_PASCAL; - const int nwarps = NWARPS_Q8_0_PASCAL; - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q2_K_RDNA2 64 -#define MMQ_Y_Q2_K_RDNA2 128 -#define NWARPS_Q2_K_RDNA2 8 -#define MMQ_X_Q2_K_RDNA1 128 -#define MMQ_Y_Q2_K_RDNA1 32 -#define NWARPS_Q2_K_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q2_K_AMPERE 4 -#define MMQ_Y_Q2_K_AMPERE 32 -#define NWARPS_Q2_K_AMPERE 4 -#else -#define MMQ_X_Q2_K_AMPERE 64 -#define MMQ_Y_Q2_K_AMPERE 128 -#define NWARPS_Q2_K_AMPERE 4 -#endif -#define MMQ_X_Q2_K_PASCAL 64 -#define MMQ_Y_Q2_K_PASCAL 64 -#define NWARPS_Q2_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q2_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q2_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q2_K_RDNA2; - const int mmq_y = MMQ_Y_Q2_K_RDNA2; - const int nwarps = NWARPS_Q2_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q2_K_RDNA1; - const int mmq_y = MMQ_Y_Q2_K_RDNA1; - const int nwarps = NWARPS_Q2_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q2_K_AMPERE; - const int mmq_y = MMQ_Y_Q2_K_AMPERE; - const int nwarps = NWARPS_Q2_K_AMPERE; - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q2_K_PASCAL; - const int mmq_y = MMQ_Y_Q2_K_PASCAL; - const int nwarps = NWARPS_Q2_K_PASCAL; - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q3_K_RDNA2 128 -#define MMQ_Y_Q3_K_RDNA2 64 -#define NWARPS_Q3_K_RDNA2 8 -#define MMQ_X_Q3_K_RDNA1 32 -#define MMQ_Y_Q3_K_RDNA1 128 -#define NWARPS_Q3_K_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q3_K_AMPERE 4 -#define MMQ_Y_Q3_K_AMPERE 32 -#define NWARPS_Q3_K_AMPERE 4 -#else -#define MMQ_X_Q3_K_AMPERE 128 -#define MMQ_Y_Q3_K_AMPERE 128 -#define NWARPS_Q3_K_AMPERE 4 -#endif -#define MMQ_X_Q3_K_PASCAL 64 -#define MMQ_Y_Q3_K_PASCAL 64 -#define NWARPS_Q3_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q3_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q3_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q3_K_RDNA2; - const int mmq_y = MMQ_Y_Q3_K_RDNA2; - const int nwarps = NWARPS_Q3_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q3_K_RDNA1; - const int mmq_y = MMQ_Y_Q3_K_RDNA1; - const int nwarps = NWARPS_Q3_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q3_K_AMPERE; - const int mmq_y = MMQ_Y_Q3_K_AMPERE; - const int nwarps = NWARPS_Q3_K_AMPERE; - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q3_K_PASCAL; - const int mmq_y = MMQ_Y_Q3_K_PASCAL; - const int nwarps = NWARPS_Q3_K_PASCAL; - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q4_K_RDNA2 64 -#define MMQ_Y_Q4_K_RDNA2 128 -#define NWARPS_Q4_K_RDNA2 8 -#define MMQ_X_Q4_K_RDNA1 32 -#define MMQ_Y_Q4_K_RDNA1 64 -#define NWARPS_Q4_K_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q4_K_AMPERE 4 -#define MMQ_Y_Q4_K_AMPERE 32 -#define NWARPS_Q4_K_AMPERE 4 -#else -#define MMQ_X_Q4_K_AMPERE 64 -#define MMQ_Y_Q4_K_AMPERE 128 -#define NWARPS_Q4_K_AMPERE 4 -#endif -#define MMQ_X_Q4_K_PASCAL 64 -#define MMQ_Y_Q4_K_PASCAL 64 -#define NWARPS_Q4_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q4_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q4_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q4_K_RDNA2; - const int mmq_y = MMQ_Y_Q4_K_RDNA2; - const int nwarps = NWARPS_Q4_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q4_K_RDNA1; - const int mmq_y = MMQ_Y_Q4_K_RDNA1; - const int nwarps = NWARPS_Q4_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q4_K_AMPERE; - const int mmq_y = MMQ_Y_Q4_K_AMPERE; - const int nwarps = NWARPS_Q4_K_AMPERE; - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q4_K_PASCAL; - const int mmq_y = MMQ_Y_Q4_K_PASCAL; - const int nwarps = NWARPS_Q4_K_PASCAL; - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q5_K_RDNA2 64 -#define MMQ_Y_Q5_K_RDNA2 128 -#define NWARPS_Q5_K_RDNA2 8 -#define MMQ_X_Q5_K_RDNA1 32 -#define MMQ_Y_Q5_K_RDNA1 64 -#define NWARPS_Q5_K_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q5_K_AMPERE 4 -#define MMQ_Y_Q5_K_AMPERE 32 -#define NWARPS_Q5_K_AMPERE 4 -#else -#define MMQ_X_Q5_K_AMPERE 64 -#define MMQ_Y_Q5_K_AMPERE 128 -#define NWARPS_Q5_K_AMPERE 4 -#endif -#define MMQ_X_Q5_K_PASCAL 64 -#define MMQ_Y_Q5_K_PASCAL 64 -#define NWARPS_Q5_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q5_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -mul_mat_q5_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q5_K_RDNA2; - const int mmq_y = MMQ_Y_Q5_K_RDNA2; - const int nwarps = NWARPS_Q5_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q5_K_RDNA1; - const int mmq_y = MMQ_Y_Q5_K_RDNA1; - const int nwarps = NWARPS_Q5_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q5_K_AMPERE; - const int mmq_y = MMQ_Y_Q5_K_AMPERE; - const int nwarps = NWARPS_Q5_K_AMPERE; - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q5_K_PASCAL; - const int mmq_y = MMQ_Y_Q5_K_PASCAL; - const int nwarps = NWARPS_Q5_K_PASCAL; - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -#define MMQ_X_Q6_K_RDNA2 64 -#define MMQ_Y_Q6_K_RDNA2 128 -#define NWARPS_Q6_K_RDNA2 8 -#define MMQ_X_Q6_K_RDNA1 32 -#define MMQ_Y_Q6_K_RDNA1 64 -#define NWARPS_Q6_K_RDNA1 8 -#if defined(CUDA_USE_TENSOR_CORES) -#define MMQ_X_Q6_K_AMPERE 4 -#define MMQ_Y_Q6_K_AMPERE 32 -#define NWARPS_Q6_K_AMPERE 4 -#else -#define MMQ_X_Q6_K_AMPERE 64 -#define MMQ_Y_Q6_K_AMPERE 64 -#define NWARPS_Q6_K_AMPERE 4 -#endif -#define MMQ_X_Q6_K_PASCAL 64 -#define MMQ_Y_Q6_K_PASCAL 64 -#define NWARPS_Q6_K_PASCAL 8 - -template static __global__ void -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_RDNA2, 2) -#endif // defined(RDNA3) || defined(RDNA2) -#elif __CUDA_ARCH__ < CC_VOLTA - __launch_bounds__(WARP_SIZE*NWARPS_Q6_K_PASCAL, 2) -#endif // __CUDA_ARCH__ < CC_VOLTA - mul_mat_q6_K( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { - -#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) -#if defined(RDNA3) || defined(RDNA2) - const int mmq_x = MMQ_X_Q6_K_RDNA2; - const int mmq_y = MMQ_Y_Q6_K_RDNA2; - const int nwarps = NWARPS_Q6_K_RDNA2; -#else - const int mmq_x = MMQ_X_Q6_K_RDNA1; - const int mmq_y = MMQ_Y_Q6_K_RDNA1; - const int nwarps = NWARPS_Q6_K_RDNA1; -#endif // defined(RDNA3) || defined(RDNA2) - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= CC_VOLTA - const int mmq_x = MMQ_X_Q6_K_AMPERE; - const int mmq_y = MMQ_Y_Q6_K_AMPERE; - const int nwarps = NWARPS_Q6_K_AMPERE; - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - -#elif __CUDA_ARCH__ >= MIN_CC_DP4A - const int mmq_x = MMQ_X_Q6_K_PASCAL; - const int mmq_y = MMQ_Y_Q6_K_PASCAL; - const int nwarps = NWARPS_Q6_K_PASCAL; - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); -#else - GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat); - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= CC_VOLTA -} - -static void ggml_mul_mat_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q4_0_RDNA2; - mmq_y = MMQ_Y_Q4_0_RDNA2; - nwarps = NWARPS_Q4_0_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q4_0_RDNA1; - mmq_y = MMQ_Y_Q4_0_RDNA1; - nwarps = NWARPS_Q4_0_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q4_0_AMPERE; - mmq_y = MMQ_Y_Q4_0_AMPERE; - nwarps = NWARPS_Q4_0_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q4_0_PASCAL; - mmq_y = MMQ_Y_Q4_0_PASCAL; - nwarps = NWARPS_Q4_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q4_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q4_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q4_1_RDNA2; - mmq_y = MMQ_Y_Q4_1_RDNA2; - nwarps = NWARPS_Q4_1_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q4_1_RDNA1; - mmq_y = MMQ_Y_Q4_1_RDNA1; - nwarps = NWARPS_Q4_1_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q4_1_AMPERE; - mmq_y = MMQ_Y_Q4_1_AMPERE; - nwarps = NWARPS_Q4_1_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q4_1_PASCAL; - mmq_y = MMQ_Y_Q4_1_PASCAL; - nwarps = NWARPS_Q4_1_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q4_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q4_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q5_0_RDNA2; - mmq_y = MMQ_Y_Q5_0_RDNA2; - nwarps = NWARPS_Q5_0_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q5_0_RDNA1; - mmq_y = MMQ_Y_Q5_0_RDNA1; - nwarps = NWARPS_Q5_0_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q5_0_AMPERE; - mmq_y = MMQ_Y_Q5_0_AMPERE; - nwarps = NWARPS_Q5_0_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q5_0_PASCAL; - mmq_y = MMQ_Y_Q5_0_PASCAL; - nwarps = NWARPS_Q5_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q5_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q5_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q5_1_RDNA2; - mmq_y = MMQ_Y_Q5_1_RDNA2; - nwarps = NWARPS_Q5_1_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q5_1_RDNA1; - mmq_y = MMQ_Y_Q5_1_RDNA1; - nwarps = NWARPS_Q5_1_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q5_1_AMPERE; - mmq_y = MMQ_Y_Q5_1_AMPERE; - nwarps = NWARPS_Q5_1_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q5_1_PASCAL; - mmq_y = MMQ_Y_Q5_1_PASCAL; - nwarps = NWARPS_Q5_1_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q5_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q5_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q8_0_RDNA2; - mmq_y = MMQ_Y_Q8_0_RDNA2; - nwarps = NWARPS_Q8_0_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q8_0_RDNA1; - mmq_y = MMQ_Y_Q8_0_RDNA1; - nwarps = NWARPS_Q8_0_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q8_0_AMPERE; - mmq_y = MMQ_Y_Q8_0_AMPERE; - nwarps = NWARPS_Q8_0_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q8_0_PASCAL; - mmq_y = MMQ_Y_Q8_0_PASCAL; - nwarps = NWARPS_Q8_0_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q8_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q8_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q2_K_RDNA2; - mmq_y = MMQ_Y_Q2_K_RDNA2; - nwarps = NWARPS_Q2_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q2_K_RDNA1; - mmq_y = MMQ_Y_Q2_K_RDNA1; - nwarps = NWARPS_Q2_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q2_K_AMPERE; - mmq_y = MMQ_Y_Q2_K_AMPERE; - nwarps = NWARPS_Q2_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q2_K_PASCAL; - mmq_y = MMQ_Y_Q2_K_PASCAL; - nwarps = NWARPS_Q2_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q2_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q2_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - -#if QK_K == 256 - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q3_K_RDNA2; - mmq_y = MMQ_Y_Q3_K_RDNA2; - nwarps = NWARPS_Q3_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q3_K_RDNA1; - mmq_y = MMQ_Y_Q3_K_RDNA1; - nwarps = NWARPS_Q3_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q3_K_AMPERE; - mmq_y = MMQ_Y_Q3_K_AMPERE; - nwarps = NWARPS_Q3_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q3_K_PASCAL; - mmq_y = MMQ_Y_Q3_K_PASCAL; - nwarps = NWARPS_Q3_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q3_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q3_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -#endif -} - -static void ggml_mul_mat_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q4_K_RDNA2; - mmq_y = MMQ_Y_Q4_K_RDNA2; - nwarps = NWARPS_Q4_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q4_K_RDNA1; - mmq_y = MMQ_Y_Q4_K_RDNA1; - nwarps = NWARPS_Q4_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q4_K_AMPERE; - mmq_y = MMQ_Y_Q4_K_AMPERE; - nwarps = NWARPS_Q4_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q4_K_PASCAL; - mmq_y = MMQ_Y_Q4_K_PASCAL; - nwarps = NWARPS_Q4_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q4_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q4_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q5_K_RDNA2; - mmq_y = MMQ_Y_Q5_K_RDNA2; - nwarps = NWARPS_Q5_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q5_K_RDNA1; - mmq_y = MMQ_Y_Q5_K_RDNA1; - nwarps = NWARPS_Q5_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q5_K_AMPERE; - mmq_y = MMQ_Y_Q5_K_AMPERE; - nwarps = NWARPS_Q5_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q5_K_PASCAL; - mmq_y = MMQ_Y_Q5_K_PASCAL; - nwarps = NWARPS_Q5_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q5_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q5_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} - -static void ggml_mul_mat_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { - - int id = ggml_cuda_get_device(); - const int compute_capability = ggml_cuda_info().devices[id].cc; - - int mmq_x, mmq_y, nwarps; - if (compute_capability >= CC_RDNA2) { - mmq_x = MMQ_X_Q6_K_RDNA2; - mmq_y = MMQ_Y_Q6_K_RDNA2; - nwarps = NWARPS_Q6_K_RDNA2; - } else if (compute_capability >= CC_OFFSET_AMD) { - mmq_x = MMQ_X_Q6_K_RDNA1; - mmq_y = MMQ_Y_Q6_K_RDNA1; - nwarps = NWARPS_Q6_K_RDNA1; - } else if (compute_capability >= CC_VOLTA) { - mmq_x = MMQ_X_Q6_K_AMPERE; - mmq_y = MMQ_Y_Q6_K_AMPERE; - nwarps = NWARPS_Q6_K_AMPERE; - } else if (compute_capability >= MIN_CC_DP4A) { - mmq_x = MMQ_X_Q6_K_PASCAL; - mmq_y = MMQ_Y_Q6_K_PASCAL; - nwarps = NWARPS_Q6_K_PASCAL; - } else { - GGML_ASSERT(false); - } - - const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; - const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); - - if (nrows_x % mmq_y == 0) { - const bool need_check = false; - mul_mat_q6_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } else { - const bool need_check = true; - mul_mat_q6_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - } -} void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, @@ -2182,49 +8,55 @@ void ggml_cuda_op_mul_mat_q( const int64_t ne00 = src0->ne[0]; + const int64_t nb01 = src0->nb[1]; + const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; const int64_t row_diff = row_high - row_low; + const int64_t stride00 = nb01 / ggml_type_size(src0->type); int id = ggml_cuda_get_device(); + const int compute_capability = ggml_cuda_info().devices[id].cc; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst}; + switch (src0->type) { case GGML_TYPE_Q4_0: - ggml_mul_mat_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q4_1: - ggml_mul_mat_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q5_0: - ggml_mul_mat_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q5_1: - ggml_mul_mat_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q8_0: - ggml_mul_mat_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q2_K: - ggml_mul_mat_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q3_K: - ggml_mul_mat_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q4_K: - ggml_mul_mat_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q5_K: - ggml_mul_mat_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; case GGML_TYPE_Q6_K: - ggml_mul_mat_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream); + mul_mat_q_case(args, stream); break; default: GGML_ASSERT(false); diff --git a/llama/ggml-cuda/mmq.cuh b/llama/ggml-cuda/mmq.cuh index 807817c4..6744cce6 100644 --- a/llama/ggml-cuda/mmq.cuh +++ b/llama/ggml-cuda/mmq.cuh @@ -1,4 +1,1304 @@ #include "common.cuh" +#include "vecdotq.cuh" + +#include +#include + +typedef void (*load_tiles_mmq_t)( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride); +typedef void (*vec_dot_mmq_t)( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, float * __restrict__ sum, const int & k0); + +struct tile_x_sizes { + int ql; + int dm; + int qh; + int sc; +}; + +// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row + +static constexpr __device__ int get_mmq_x_max_device() { +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + return 64; +#else +#if __CUDA_ARCH__ >= CC_VOLTA +#ifdef CUDA_USE_TENSOR_CORES + return MMQ_MAX_BATCH_SIZE; +#else + return 128; +#endif // CUDA_USE_TENSOR_CORES +#else + return 64; +#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +} + +// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row + +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +static constexpr __device__ int get_mmq_y_device(int mmq_x) { + return mmq_x >= 32 ? 128 : 64; +} +#else +#if __CUDA_ARCH__ >= CC_VOLTA +static constexpr __device__ int get_mmq_y_device(int mmq_x) { + return mmq_x >= 32 ? 128 : 64; +} +#else +static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) { + return 64; +} +#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) + +#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0} +#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0} +#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0} +#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0} +#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0} +#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4} +#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8} + +#define GET_TILE_X_SIZES_BODY \ + return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \ + type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \ + type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \ + type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \ + type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \ + type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \ + type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \ + type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \ + type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \ + type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \ + tile_x_sizes{0, 0, 0, 0} + +static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) { + GET_TILE_X_SIZES_BODY; +} + +template +static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) { + GET_TILE_X_SIZES_BODY; +} + +// ------------------------------------------------------------ + +template static __device__ __forceinline__ void load_tiles_q4_0( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI4_0; + const int kqsx = threadIdx.x % QI4_0; + + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { + int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d; + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const float * x_dmf = (const float *) x_dm; + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_1( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI4_1; + const int kqsx = threadIdx.x % QI4_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { + int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k0], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1], + y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_0( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI5_0; + const int kqsx = threadIdx.x % QI5_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_from_uint8(bxi->qs, kqsx); + const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { + int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d; + } +} + +template +static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + int u[2*VDR_Q5_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + + +template static __device__ __forceinline__ void load_tiles_q5_1( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI5_1; + const int kqsx = threadIdx.x % QI5_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0; + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + + x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { + int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + } +} + +template +static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2)); + const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k0/QI5_1; + + int u[2*VDR_Q5_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; + u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + } + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl + (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q8_0( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + + const int kbx = threadIdx.x / QI8_0; + const int kqsx = threadIdx.x % QI8_0; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) { + int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d; + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); GGML_UNUSED(x_sc); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl + (&x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[j * WARP_SIZE + k0], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0], + y_df[j * (WARP_SIZE/QI8_1) + k0/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q2_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = threadIdx.x / QI2_K; + const int kqsx = threadIdx.x % QI2_K; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) { + int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4); + + x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4)); + } +} + +template +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kbx = k0 / QI2_K; + const int ky = (k0 % QI2_K) * QR2_K; + const float * y_df = (const float *) y_ds; + + int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; + + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); + +#pragma unroll + for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) { + v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; + } + + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + + const int index_y = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( + v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q3_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + + const int kbx = threadIdx.x / QI3_K; + const int kqsx = threadIdx.x % QI3_K; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) { + int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { + int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2)); + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { + int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4); + + const int ksc = threadIdx.x % (QI3_K/4); + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + + x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc; + } +} + +template +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const int kbx = k0 / QI3_K; + const int ky = (k0 % QI3_K) * QR3_K; + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + + int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { + const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int shift = 2 * ((ky % 32) / 8); + const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; + + const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vlh = (vh << 2) & 0x04040404; + + v[l] = __vsubss4(vll, vlh); + } + + const int index_y = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( + v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = 0; // threadIdx.x / QI4_K + const int kqsx = threadIdx.x; // threadIdx.x % QI4_K + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; + + x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { + int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +template +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8); + + const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( + &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q5_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = 0; // threadIdx.x / QI5_K + const int kqsx = threadIdx.x; // threadIdx.x % QI5_K + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx; + const int ky = QR5_K*kqsx; + + const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { + int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; + + x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (WARP_SIZE/8); + + // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 + int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits + scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits + + x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + } +} + +template +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8); + + const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0; + const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( + &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + } + } +} + +template static __device__ __forceinline__ void load_tiles_q6_K( + const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh, + int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) { + GGML_UNUSED(x_qh); + + const int kbx = 0; // threadIdx.x / QI6_K + const int kqsx = threadIdx.x; // threadIdx.x % QI6_K + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx; + const int ky = QR6_K*kqsx; + + const int ql = get_int_from_uint8(bxi->ql, kqsx); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); + const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; + const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + + const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; + const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); + + x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + } + + const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 + float * x_dmf = (float *) x_dm; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { + int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + + x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { + int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + + x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8)); + } +} + +template +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat( + const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, + const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) { + + GGML_UNUSED(x_qh); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float * x_dmf = (const float *) x_dm; + const float * y_df = (const float *) y_ds; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]); + + const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0; + const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE; + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( + &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + } + } +} + +// ------------------------------------------------------------------------------------------------------------------------------------- + +template +struct mmq_type_traits; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = true; + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat; +}; + +template +struct mmq_type_traits { + static constexpr bool need_sum = false; + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat; +}; + +template +#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +#if defined(RDNA3) || defined(RDNA2) + __launch_bounds__(WARP_SIZE*nwarps, 2) +#endif // defined(RDNA3) || defined(RDNA2) +#else +#if __CUDA_ARCH__ >= CC_VOLTA + __launch_bounds__(WARP_SIZE*nwarps, 1) +#else + __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2) +#endif // __CUDA_ARCH__ >= CC_VOLTA +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) +static __global__ void mul_mat_q( + const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, + const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) { + + // Skip unused template specializations for faster compilation: + if (mmq_x > get_mmq_x_max_device()) { + NO_DEVICE_CODE; + return; + } + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qr = ggml_cuda_type_traits::qr; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int mmq_y = get_mmq_y_device(mmq_x); + constexpr bool need_sum = mmq_type_traits::need_sum; + constexpr int vdr = mmq_type_traits::vdr; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot; + + constexpr tile_x_sizes txs = get_tile_x_sizes_device(type); + + extern __shared__ char data_mul_mat_q[]; + int * tile_x_ql = (int *) data_mul_mat_q; + half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql); + int * tile_x_qh = (int *) (tile_x_dm + txs.dm); + int * tile_x_sc = (int *) (tile_x_qh + txs.qh); + int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE] + half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1]; + + const block_q8_1 * y = (const block_q8_1 *) yc; + + const int blocks_per_row_x = ne00 / qk; + const int blocks_per_col_y = ne10 / QK8_1; + const int blocks_per_warp = WARP_SIZE / qi; + + const int & ne1 = ne11; + + const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1; + + float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f}; + + for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) { + + load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00); + +#pragma unroll + for (int kr = 0; kr < qr; ++kr) { + const int kqs = kr*WARP_SIZE + threadIdx.x; + const int kbxd = kqs / QI8_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_x; i0 += nwarps) { + const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses + + const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd]; + + const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE; + tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); + } + +#pragma unroll + for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1); + + // if the sum is not needed it's faster to transform the scale to f32 ahead of time + const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + if (need_sum) { + *dsi_dst = *dsi_src; + } else { + float * dfi_dst = (float *) dsi_dst; + *dfi_dst = __low2float(*dsi_src); + } + } + + __syncthreads(); + +// #pragma unroll // unrolling this loop causes too much register pressure + for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) { + vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0); + } + + __syncthreads(); + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = blockIdx.y*mmq_x + j0 + threadIdx.y; + + if (j >= ne1) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = blockIdx.x*mmq_y + i0 + threadIdx.x; + + if (need_check && i >= ne0) { + continue; + } + + dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + } + } +} + +struct mmq_args { + const char * x; const char * y; float * dst; + int64_t ne00; int64_t ne01; int64_t stride00; + int64_t ne10; int64_t ne11; + int64_t ne0; +}; + +template +static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int mmq_y = get_mmq_y_host(cc, mmq_x); + + const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y; + const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x; + const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y); + const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int); + const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2); + const int shmem = shmem_x + shmem_y; + +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shmem_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); + shmem_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) + + if (args.ne01 % mmq_y == 0) { + const bool need_check = false; + mul_mat_q<<>> + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + } else { + const bool need_check = true; + mul_mat_q<<>> + (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0); + } +} + +template +void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + const int cc = ggml_cuda_info().devices[id].cc; + + const int mmq_x_max = get_mmq_x_max_host(cc); + const int mmq_y = get_mmq_y_host(cc, mmq_x_max); + const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; + + int mmq_x_best = 0; + int nwaves_best = INT_MAX; + + for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) { + const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x; + const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm; + + if (nwaves < nwaves_best) { + mmq_x_best = mmq_x; + nwaves_best = nwaves; + } + } + + switch (mmq_x_best) { + case 8: + launch_mul_mat_q(args, stream); + break; + case 16: + launch_mul_mat_q(args, stream); + break; + case 24: + launch_mul_mat_q(args, stream); + break; + case 32: + launch_mul_mat_q(args, stream); + break; + case 40: + launch_mul_mat_q(args, stream); + break; + case 48: + launch_mul_mat_q(args, stream); + break; + case 56: + launch_mul_mat_q(args, stream); + break; + case 64: + launch_mul_mat_q(args, stream); + break; + case 72: + launch_mul_mat_q(args, stream); + break; + case 80: + launch_mul_mat_q(args, stream); + break; + case 88: + launch_mul_mat_q(args, stream); + break; + case 96: + launch_mul_mat_q(args, stream); + break; + case 104: + launch_mul_mat_q(args, stream); + break; + case 112: + launch_mul_mat_q(args, stream); + break; + case 120: + launch_mul_mat_q(args, stream); + break; + case 128: + launch_mul_mat_q(args, stream); + break; + default: + GGML_ASSERT(false); + break; + } +} + +#define DECL_MMQ_CASE(type) \ + template void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) \ + +extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); + +// ------------------------------------------------------------------------------------------------------------------------- void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, diff --git a/llama/ggml-cuda/mmvq.cu b/llama/ggml-cuda/mmvq.cu index 65cc1bca..5f056e91 100644 --- a/llama/ggml-cuda/mmvq.cu +++ b/llama/ggml-cuda/mmvq.cu @@ -1,9 +1,47 @@ #include "mmvq.cuh" #include "vecdotq.cuh" -typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); -template +static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : + type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : + type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : + type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : + type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : + type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : + type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : + type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : + type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : + type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : + type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : + type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : + type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : + type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : + type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : + type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : + type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : + type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : + type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : + nullptr; +} + +static constexpr __device__ int get_vdr_mmvq(ggml_type type) { + return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : + type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : + type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : + type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : + type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : + type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : + type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : + type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : + type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : + type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : + type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ : + 1; +} + +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // tell the compiler to use as many registers as it wants, see nwarps definition below __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) @@ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) constexpr int nwarps = 1; constexpr int rows_per_cuda_block = 1; @@ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q( // partial sum for each thread float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; - const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q( for (int j = 0; j < ncols_y; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { - tmp[j][i] += vec_dot_q_cuda( - &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs); + tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); } } } @@ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q( } } -template +template static void mul_mat_vec_q_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - GGML_ASSERT(ncols_x % qk == 0); + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); int id = ggml_cuda_get_device(); @@ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda( switch (ncols_y) { case 1: - mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 2: - mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 3: - mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 4: - mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 5: - mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 6: - mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 7: - mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; case 8: - mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); break; default: GGML_ASSERT(false); @@ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q3_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q5_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq2_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq2_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq2_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq3_xxs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq1_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq1_m_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq4_nl_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq4_xs_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } static void mul_mat_vec_iq3_s_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - mul_mat_vec_q_cuda - (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); } void ggml_cuda_op_mul_mat_vec_q( diff --git a/llama/ggml-cuda/norm.cu b/llama/ggml-cuda/norm.cu index 86f77453..30866d51 100644 --- a/llama/ggml-cuda/norm.cu +++ b/llama/ggml-cuda/norm.cu @@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); diff --git a/llama/ggml-cuda/rope.cu b/llama/ggml-cuda/rope.cu index 4b0d2e5a..596fb7c1 100644 --- a/llama/ggml-cuda/rope.cu +++ b/llama/ggml-cuda/rope.cu @@ -1,7 +1,7 @@ #include "rope.cuh" struct rope_corr_dims { - float v[4]; + float v[2]; }; static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -13,8 +13,7 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static __device__ void rope_yarn( float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -29,27 +28,38 @@ static __device__ void rope_yarn( *sin_theta = sinf(theta) * mscale; } -// rope == RoPE == rotary positional embedding -template -static __global__ void rope( - const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - float ext_factor, float attn_factor, rope_corr_dims corr_dims -) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); +template +static __global__ void rope_norm( + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int i = row*ncols + col; + + if (i0 >= n_dims) { + const int i = row*ne0 + i0; + + dst[i + 0] = x[i + 0]; + dst[i + 1] = x[i + 1]; + + return; + } + + const int i = row*ne0 + i0; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*powf(freq_base, -float(col)/ncols); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + 1]; @@ -58,23 +68,20 @@ static __global__ void rope( dst[i + 1] = x0*sin_theta + x1*cos_theta; } -template +template static __global__ void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims -) { - const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); + const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (col >= ncols) { + if (i0 >= ne0) { return; } const int row = blockDim.x*blockIdx.x + threadIdx.x; - const int ib = col / n_dims; - const int ic = col % n_dims; - if (ib > 0) { - const int i = row*ncols + ib*n_dims + ic; + if (i0 >= n_dims) { + const int i = row*ne0 + i0; dst[i + 0] = x[i + 0]; dst[i + 1] = x[i + 1]; @@ -82,16 +89,17 @@ static __global__ void rope_neox( return; } - const int i = row*ncols + ib*n_dims + ic/2; + const int i = row*ne0 + i0/2; const int i2 = row/p_delta_rows; - float cur_rot = inv_ndims * ic - ib; + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); - const int p = has_pos ? pos[i2] : 0; - const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; - float cos_theta, sin_theta; - rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); + float cos_theta; + float sin_theta; + + rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); const float x0 = x[i + 0]; const float x1 = x[i + n_dims/2]; @@ -100,158 +108,117 @@ static __global__ void rope_neox( dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; } -static __global__ void rope_glm_f32( - const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, - int n_ctx -) { - const int col = blockDim.x*blockIdx.x + threadIdx.x; - const int half_n_dims = ncols/4; - - if (col >= half_n_dims) { - return; - } - - const int row = blockDim.y*blockIdx.y + threadIdx.y; - const int i = row*ncols + col; - const int i2 = row/p_delta_rows; - - const float col_theta_scale = powf(freq_base, -2.0f*col/ncols); - // FIXME: this is likely wrong - const int p = pos != nullptr ? pos[i2] : 0; - - const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; - const float sin_theta = sinf(theta); - const float cos_theta = cosf(theta); - - const float x0 = x[i + 0]; - const float x1 = x[i + half_n_dims]; - - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - - const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; - const float sin_block_theta = sinf(block_theta); - const float cos_block_theta = cosf(block_theta); - - const float x2 = x[i + half_n_dims * 2]; - const float x3 = x[i + half_n_dims * 3]; - - dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta; - dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta; -} - - template -static void rope_cuda( - const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream -) { - GGML_ASSERT(ncols % 2 == 0); +static void rope_norm_cuda( + const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); - if (pos == nullptr) { - rope<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims - ); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_norm<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } else { - rope<<>>( - x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims - ); + rope_norm<<>>( + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } } template static void rope_neox_cuda( - const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream -) { - GGML_ASSERT(ncols % 2 == 0); + const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + GGML_ASSERT(ne0 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); - const dim3 block_nums(nrows, num_blocks_x, 1); + const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.0f / n_dims; - if (pos == nullptr) { + if (freq_factors == nullptr) { rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } else { rope_neox<<>>( - x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, - theta_scale, inv_ndims - ); + x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors + ); } } -static void rope_glm_f32_cuda( - const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, int n_ctx, cudaStream_t stream -) { - GGML_ASSERT(ncols % 4 == 0); - const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1); - const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE; - const dim3 block_nums(num_blocks_x, nrows, 1); - rope_glm_f32<<>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx); +static void rope_norm_cuda_f16( + const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { + + rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } -static void rope_cuda_f16( - const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { +static void rope_norm_cuda_f32( + const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_cuda(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); -} - -static void rope_cuda_f32( - const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { - - rope_cuda(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f16( - const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) { + const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } static void rope_neox_cuda_f32( - const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream + const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows, + float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream ) { - rope_neox_cuda(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream); + rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); } void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == dst->type); const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t nrows = ggml_nrows(src0); + const int64_t nr = ggml_nrows(src0); - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; // RoPE alteration for extended context - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -259,47 +226,43 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const int32_t * pos = nullptr; - if ((mode & 1) == 0) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(src1->ne[0] == ne2); - pos = (const int32_t *) src1_d; + const bool is_neox = mode & 2; + + const int32_t * pos = (const int32_t *) src1_d; + + const float * freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; } - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - rope_corr_dims corr_dims; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); // compute - if (is_glm) { - GGML_ASSERT(false); - rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream); - } else if (is_neox) { + if (is_neox) { if (src0->type == GGML_TYPE_F32) { rope_neox_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { rope_neox_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); } } else { if (src0->type == GGML_TYPE_F32) { - rope_cuda_f32( - (const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + rope_norm_cuda_f32( + (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else if (src0->type == GGML_TYPE_F16) { - rope_cuda_f16( - (const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, stream + rope_norm_cuda_f16( + (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, + attn_factor, corr_dims, freq_factors, stream ); } else { GGML_ASSERT(false); diff --git a/llama/ggml-cuda/vecdotq.cuh b/llama/ggml-cuda/vecdotq.cuh index 86b87fa9..b9573a7c 100644 --- a/llama/ggml-cuda/vecdotq.cuh +++ b/llama/ggml-cuda/vecdotq.cuh @@ -180,8 +180,8 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp #define VDR_Q8_0_Q8_1_MMVQ 2 #define VDR_Q8_0_Q8_1_MMQ 8 -template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( - const int * v, const int * u, const float & d8_0, const float & d8_1) { +template static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const T & d8_0, const T & d8_1) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; @@ -192,7 +192,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp sumi = __dp4a(v[i], u[i], sumi); } - return d8_0*d8_1 * sumi; + return d8_0*d8_1 * ((T) sumi); #else NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -566,9 +566,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( } static __device__ __forceinline__ float vec_dot_q4_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx; int v[VDR_Q4_0_Q8_1_MMVQ]; int u[2*VDR_Q4_0_Q8_1_MMVQ]; @@ -585,9 +585,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( static __device__ __forceinline__ float vec_dot_q4_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx; int v[VDR_Q4_1_Q8_1_MMVQ]; int u[2*VDR_Q4_1_Q8_1_MMVQ]; @@ -603,9 +603,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( } static __device__ __forceinline__ float vec_dot_q5_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx; int vl[VDR_Q5_0_Q8_1_MMVQ]; int vh[VDR_Q5_0_Q8_1_MMVQ]; @@ -623,9 +623,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( } static __device__ __forceinline__ float vec_dot_q5_1_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx; int vl[VDR_Q5_1_Q8_1_MMVQ]; int vh[VDR_Q5_1_Q8_1_MMVQ]; @@ -643,9 +643,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( } static __device__ __forceinline__ float vec_dot_q8_0_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx; int v[VDR_Q8_0_Q8_1_MMVQ]; int u[VDR_Q8_0_Q8_1_MMVQ]; @@ -656,13 +656,13 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); } - return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q2_K * bq2_K = (const block_q2_K *) vbq; + const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx; const int bq8_offset = QR2_K * (iqs / QI8_1); const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); @@ -683,9 +683,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( } static __device__ __forceinline__ float vec_dot_q3_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q3_K * bq3_K = (const block_q3_K *) vbq; + const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx; const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); @@ -710,10 +710,9 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( } static __device__ __forceinline__ float vec_dot_q4_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#ifndef GGML_QKK_64 - const block_q4_K * bq4_K = (const block_q4_K *) vbq; + const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; int v[2]; int u[2*QR4_K]; @@ -754,59 +753,12 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( } return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); - -#else - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q4_K * bq4_K = (const block_q4_K *) vbq; - - float sumf_d = 0.0f; - float sumf_m = 0.0f; - - uint16_t aux16[2]; - const uint8_t * s = (const uint8_t *)aux16; - - const uint16_t * a = (const uint16_t *)bq4_K->scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - const float dall = bq4_K->dm[0]; - const float dmin = bq4_K->dm[1]; - - const float d8_1 = __low2float(bq8_1[0].ds); - const float d8_2 = __low2float(bq8_1[1].ds); - - const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); - const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); - const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); - const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); - - const int * q4 = (const int *)bq4_K->qs + (iqs/2); - const int v1 = q4[0]; - const int v2 = q4[4]; - - const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); - const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); - const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); - - sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); - sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); - - return dall * sumf_d - dmin * sumf_m; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A - -#endif } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#ifndef GGML_QKK_64 - const block_q5_K * bq5_K = (const block_q5_K *) vbq; + const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx; int vl[2]; int vh[2]; @@ -847,54 +799,12 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( } return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); - -#else - -#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - const block_q5_K * bq5_K = (const block_q5_K *) vbq; - - const int8_t * s = bq5_K->scales; - - const float d = bq5_K->d; - - const float d8_1 = __low2half(bq8_1[0].ds); - const float d8_2 = __low2half(bq8_1[1].ds); - - const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); - const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); - const int ui3 = *((const int *)bq8_1[1].qs + (iqs/2)); - const int ui4 = *((const int *)bq8_1[1].qs + (iqs/2) + 4); - - const int * ql = (const int *)bq5_K->qs + (iqs/2); - const int vl1 = ql[0]; - const int vl2 = ql[4]; - - const int step = 4 * (iqs/2); // 0, 4, 8, 12 - const int im = step/8; // = 0 for iqs = 0, 2, = 1 for iqs = 4, 6 - const int in = step%8; // 0, 4, 0, 4 - const int vh = (*((const int *)(bq5_K->qh + in))) >> im; - - const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f); - const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f); - const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); - const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); - - const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) - + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); - - return d * sumf_d; - -#else - NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= MIN_CC_DP4A - -#endif } static __device__ __forceinline__ float vec_dot_q6_K_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_q6_K * bq6_K = (const block_q6_K *) vbq; + const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx; const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); @@ -918,9 +828,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( } static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if QK_K == 256 - const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx; #if QR2_XXS == 8 const int ib32 = iqs; @@ -960,16 +869,12 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( } return d * (sumi1 + sumi2); #endif -#else - NO_DEVICE_CODE; -#endif } static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -#if QK_K == 256 - const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq; + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx; const int ib32 = iqs; const uint16_t * q2 = bq2->qs + 4*ib32; @@ -1002,18 +907,13 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( GGML_UNUSED(ksigns64); NO_DEVICE_CODE; #endif -#else - GGML_UNUSED(ksigns64); - NO_DEVICE_CODE; -#endif } // TODO static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -#if QK_K == 256 - const block_iq2_s * bq2 = (const block_iq2_s *) vbq; + const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx; const int ib32 = iqs; const int8_t * q8 = bq8_1[ib32].qs; @@ -1048,17 +948,12 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( GGML_UNUSED(ksigns64); NO_DEVICE_CODE; #endif -#else - GGML_UNUSED(ksigns64); - NO_DEVICE_CODE; -#endif } static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -#if QK_K == 256 - const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; + const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq + kbx; const int ib32 = iqs; const uint8_t * q3 = bq2->qs + 8*ib32; @@ -1082,17 +977,13 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #else NO_DEVICE_CODE; #endif -#else - NO_DEVICE_CODE; -#endif } // TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics -#if QK_K == 256 - const block_iq3_s * bq2 = (const block_iq3_s *) vbq; + const block_iq3_s * bq2 = (const block_iq3_s *) vbq + kbx; const int ib32 = iqs; const uint8_t * qs = bq2->qs + 8*ib32; @@ -1114,15 +1005,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( #else NO_DEVICE_CODE; #endif -#else - NO_DEVICE_CODE; -#endif } static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if QK_K == 256 - const block_iq1_s * bq1 = (const block_iq1_s *) vbq; + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx; const int ib32 = iqs; int sumi = 0; @@ -1149,15 +1036,11 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const float d = d1q * __low2float (bq8_1[ib32].ds); const float m = d1q * __high2float(bq8_1[ib32].ds); return d * sumi + m * delta; -#else - NO_DEVICE_CODE; -#endif } static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if QK_K == 256 - const block_iq1_m * bq1 = (const block_iq1_m *) vbq; + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx; const int ib32 = iqs; int sumi[2] = {0, 0}; @@ -1192,9 +1075,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); const float d = (float)scale.f16 * __low2float (bq8_1[ib32].ds); return d * ((sumi[0] + sumf[0]) * (2*((sc[ib32/2] >> 6*(ib32%2)) & 0x7) + 1) + (sumi[1] + sumf[1]) * (2*((sc[ib32/2] >> (6*(ib32%2)+3)) & 0x7) + 1)); -#else - NO_DEVICE_CODE; -#endif } #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -1214,9 +1094,9 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4 #endif static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { - const block_iq4_nl * bq = (const block_iq4_nl *) vbq; + const block_iq4_nl * bq = (const block_iq4_nl *) vbq + kbx; #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs; @@ -1248,12 +1128,10 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( } static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( - const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { -#if QK_K == 256 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics - - const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; + const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx; const uint8_t * values = (const uint8_t *)kvalues_iq4nl; // iqs is 0...7 @@ -1270,11 +1148,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( sumi2 = __dp4a(v2, q8[j+4], sumi2); } return d * (sumi1 + sumi2); - #else - NO_DEVICE_CODE; -#endif -#else - return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs); + return vec_dot_iq4_xs_q8_1(vbq, bq8_1, kbx, iqs); #endif } diff --git a/llama/ggml-impl.h b/llama/ggml-impl.h index 72210a26..557c828d 100644 --- a/llama/ggml-impl.h +++ b/llama/ggml-impl.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -43,6 +43,18 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +#if defined(_WIN32) + +#define m512bh(p) p +#define m512i(p) p + +#else + +#define m512bh(p) (__m512bh)(p) +#define m512i(p) (__m512i)(p) + +#endif + /** * Converts brain16 to float32. * @@ -158,6 +170,10 @@ extern "C" { #endif #endif +#if defined(__ARM_FEATURE_SVE) +#include +#endif + // 16-bit float // on Arm, we use __fp16 // on x86, we use uint16_t @@ -469,6 +485,34 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) { #include #endif +#if defined(__loongarch64) +#if defined(__loongarch_asx) +#include +#endif +#if defined(__loongarch_sx) +#include +#endif +#endif + +#if defined(__loongarch_asx) + +typedef union { + int32_t i; + float f; +} ft_union; + +/* float type data load instructions */ +static __m128 __lsx_vreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i); +} + +static __m256 __lasx_xvreplfr2vr_s(float val) { + ft_union fi_tmpval = {.f = val}; + return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i); +} +#endif + #ifdef __F16C__ #ifdef _MSC_VER diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m index f4c2c412..cd151a17 100644 --- a/llama/ggml-metal-darwin_arm64.m +++ b/llama/ggml-metal-darwin_arm64.m @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -61,6 +61,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_ROW, GGML_METAL_KERNEL_TYPE_DIV, GGML_METAL_KERNEL_TYPE_DIV_ROW, + GGML_METAL_KERNEL_TYPE_REPEAT_F32, + GGML_METAL_KERNEL_TYPE_REPEAT_F16, + GGML_METAL_KERNEL_TYPE_REPEAT_I32, + GGML_METAL_KERNEL_TYPE_REPEAT_I16, GGML_METAL_KERNEL_TYPE_SCALE, GGML_METAL_KERNEL_TYPE_SCALE_4, GGML_METAL_KERNEL_TYPE_CLAMP, @@ -194,8 +198,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -210,9 +216,9 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, - GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, + //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261 GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -407,10 +413,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { // dictionary of preprocessor macros NSMutableDictionary * prep = [NSMutableDictionary dictionary]; -#ifdef GGML_QKK_64 - prep[@"GGML_QKK_64"] = @(1); -#endif - MTLCompileOptions* options = [MTLCompileOptions new]; options.preprocessorMacros = prep; @@ -515,6 +517,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true); @@ -648,8 +654,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -664,9 +672,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); + //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -776,6 +784,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_ACC: case GGML_OP_MUL: case GGML_OP_DIV: + case GGML_OP_REPEAT: case GGML_OP_SCALE: case GGML_OP_CLAMP: case GGML_OP_SQR: @@ -800,6 +809,15 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_LEAKY_RELU: return true; case GGML_OP_FLASH_ATTN_EXT: + if (op->src[1]->type != GGML_TYPE_F16) { + return false; + } + if (op->src[2]->type != GGML_TYPE_F16) { + return false; + } + if (op->src[0]->ne[0] == 256) { + return false; + } return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -953,22 +971,32 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t ne10 = src1 ? src1->ne[0] : 0; const int64_t ne11 = src1 ? src1->ne[1] : 0; const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + const int64_t ne13 = src1 ? src1->ne[3] : 0; const uint64_t nb10 = src1 ? src1->nb[0] : 0; const uint64_t nb11 = src1 ? src1->nb[1] : 0; const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; @@ -996,10 +1024,10 @@ static enum ggml_status ggml_metal_graph_compute( switch (dst->op) { case GGML_OP_CONCAT: { - const int64_t nb = ne00; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline; + const int32_t dim = ((int32_t *) dst->op_params)[0]; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1028,7 +1056,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; const int nth = MIN(1024, ne0); @@ -1038,11 +1066,14 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_MUL: case GGML_OP_DIV: { + GGML_ASSERT(src0t == GGML_TYPE_F32); + GGML_ASSERT(src1t == GGML_TYPE_F32); + const size_t offs = 0; bool bcast_row = false; - int64_t nb = ne00; + int64_t nb = ne00; // used by the "row" kernels id pipeline = nil; @@ -1111,6 +1142,42 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; + case GGML_OP_REPEAT: + { + id pipeline; + + switch (src0t) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break; + case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break; + case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break; + default: GGML_ASSERT(false); + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_ACC: { GGML_ASSERT(src0t == GGML_TYPE_F32); @@ -1488,7 +1555,6 @@ static enum ggml_status ggml_metal_graph_compute( { GGML_ASSERT(ne00 == ne10); - // TODO: assert that dim2 and dim3 are contiguous GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -1499,11 +1565,9 @@ static enum ggml_status ggml_metal_graph_compute( // to the matrix-vector kernel int ne11_mm_min = 1; - // the numbers below are measured on M2 Ultra for 7B and 13B models // these numbers do not translate to other devices or model sizes // TODO: need to find a better approach - // if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { switch (src0t) { case GGML_TYPE_F16: ne11_mm_min = 2; break; case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; @@ -1518,8 +1582,6 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_TYPE_Q6_K: ne11_mm_min = 7; break; default: ne11_mm_min = 1; break; } - // } - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel @@ -1789,11 +1851,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { -#ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#else [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#endif } else if (src0t == GGML_TYPE_Q5_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -1811,16 +1869,6 @@ static enum ggml_status ggml_metal_graph_compute( const int n_as = src0->ne[2]; // src2 = ids - const int64_t ne20 = src2->ne[0]; - const int64_t ne21 = src2->ne[1]; - const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); - const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); - const uint64_t nb21 = src2->nb[1]; - const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); - const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); GGML_ASSERT(src2t == GGML_TYPE_I32); @@ -2044,12 +2092,7 @@ static enum ggml_status ggml_metal_graph_compute( { nth0 = 4; nth1 = 16; - #if QK_K == 64 - pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; - #else pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; - #endif - } break; default: { @@ -2114,11 +2157,7 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { -#ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#else [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; -#endif } else if (src0t == GGML_TYPE_Q5_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -2179,6 +2218,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_RMS_NORM: { GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -2206,6 +2246,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_GROUP_NORM: { GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous(src0)); //float eps; //memcpy(&eps, dst->op_params, sizeof(float)); @@ -2239,6 +2280,8 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_NORM: { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -2268,9 +2311,15 @@ static enum ggml_status ggml_metal_graph_compute( const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2278,38 +2327,52 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const bool is_neox = mode & 2; + id pipeline = nil; - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; - default: GGML_ASSERT(false); - }; + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; @@ -2561,11 +2624,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); @@ -2601,7 +2659,7 @@ static enum ggml_status ggml_metal_graph_compute( case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); @@ -2614,7 +2672,7 @@ static enum ggml_status ggml_metal_graph_compute( switch (ne00) { case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break; - case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; + //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break; default: { GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); diff --git a/llama/ggml-metal.h b/llama/ggml-metal.h index 465dc35c..c117f051 100644 --- a/llama/ggml-metal.h +++ b/llama/ggml-metal.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -27,7 +27,7 @@ // An interface allowing to compute ggml_cgraph with Metal // // This is a fully functional interface that extends ggml with GPU support for Apple devices. -// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.) +// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.) // // How it works? // diff --git a/llama/ggml-metal.metal b/llama/ggml-metal.metal index 923fd72c..88b8aece 100644 --- a/llama/ggml-metal.metal +++ b/llama/ggml-metal.metal @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -194,6 +194,53 @@ kernel void kernel_div( } } +template +kernel void kernel_repeat( + device const char * src0, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 % ne03; + const int64_t i02 = i2 % ne02; + const int64_t i01 = i1 % ne01; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i00 = i0 % ne00; + *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + } +} + +typedef decltype(kernel_repeat) kernel_repeat_t; + +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; + // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( @@ -1633,8 +1680,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { + thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -1651,56 +1697,23 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); } static void rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); - dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); } -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - template -kernel void kernel_rope( +kernel void kernel_rope_norm( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1720,8 +1733,7 @@ kernel void kernel_rope( constant uint64_t & nb3, constant int & n_past, constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, + constant int & n_ctx_orig, constant float & freq_base, constant float & freq_scale, constant float & ext_factor, @@ -1735,71 +1747,130 @@ kernel void kernel_rope( const int64_t i2 = tgpig[1]; const int64_t i1 = tgpig[0]; - const bool is_neox = mode & 2; - float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); device const int32_t * pos = src1; - const int64_t p = pos[i2]; - - const float theta_0 = (float)p; + const float theta_base = (float) pos[i2]; const float inv_ndims = -1.f/n_dims; - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + float cos_theta; + float sin_theta; - const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const T x0 = src[0]; - const T x1 = src[1]; + const float x0 = src[0]; + const float x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t ib = 0; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - // simplified from `(ib * n_dims + ic) * inv_ndims` - const float cur_rot = inv_ndims*ic - ib; - - const float theta = theta_0 * pow(freq_base, cur_rot); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - - const int64_t i0 = ib*n_dims + ic/2; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; +template +kernel void kernel_rope_neox( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & n_ctx_orig, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/n_dims; + + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; typedef void (im2col_t)( device const float * x, @@ -2230,11 +2301,7 @@ kernel void kernel_flash_attn_ext_f16( // pointer to the mask device const half * mp = (device const half *) (mask + iq1*nb31); - // prepare diagonal scale matrix - simdgroup_float8x8 mscale(scale); - - // prepare diagonal slope matrix - simdgroup_float8x8 mslope(1.0f); + float slope = 1.0f; // ALiBi if (max_bias > 0.0f) { @@ -2243,7 +2310,7 @@ kernel void kernel_flash_attn_ext_f16( const float base = h < n_head_log2 ? m0 : m1; const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - mslope = simdgroup_float8x8(pow(base, exph)); + slope = pow(base, exph); } // loop over the KV cache @@ -2268,18 +2335,20 @@ kernel void kernel_flash_attn_ext_f16( simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } + simdgroup_store(mqk, ss + 8*cc, TF, 0, false); + + const short tx = tiisg%4; + const short ty = tiisg/4; + if (mask != q) { // mqk = mqk*scale + mask*slope - simdgroup_half8x8 mm; - simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); - simdgroup_multiply(mm, mslope, mm); - simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); + ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0]; + ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1]; } else { // mqk = mqk*scale - simdgroup_multiply(mqk, mscale, mqk); + ss[8*cc + ty*TF + 2*tx + 0] *= scale; + ss[8*cc + ty*TF + 2*tx + 1] *= scale; } - - simdgroup_store(mqk, ss + 8*cc, TF, 0, false); } } @@ -2442,7 +2511,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>; template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>; template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>; -template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; +//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>; template // head size, queries per threadgroup, cache items per threadgroup kernel void kernel_flash_attn_ext_vec_f16( @@ -2720,7 +2789,7 @@ kernel void kernel_flash_attn_ext_vec_f16( } template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>; -template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; +//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>; kernel void kernel_cpy_f16_f16( device const half * src0, @@ -2842,8 +2911,7 @@ kernel void kernel_cpy_f32_f16( for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - // TODO: is there a better way to handle -INFINITY? - dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0]; + dst_data[i00] = src[0]; } } @@ -3344,31 +3412,30 @@ kernel void kernel_concat( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, + constant int32_t & dim, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + device const float * x; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i02 < ne02) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; - src0_ptr += ntg.x*nb00; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; - src1_ptr += ntg.x*nb10; + x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - dst_ptr += ntg.x*nb0; + + device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; } } @@ -3411,7 +3478,6 @@ void kernel_mul_mv_q2_K_f32_impl( const int step = sizeof(block_q2_K) * nb; -#if QK_K == 256 const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 const int iq = it/4; // 0 or 1 @@ -3463,57 +3529,6 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } -#else - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0...1 - - device const float * y4 = y + ix * QK_K + 8 * it; - - for (int ib = ix; ib < nb; ib += 16) { - - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; - yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; - yl[i+16] = y4[i+32]; sumy[2] += yl[i+16]; - yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; - } - - device const uint8_t * sc = (device const uint8_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = &x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - float4 acc1 = {0.f, 0.f, 0.f, 0.f}; - float4 acc2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); - acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); - acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); - acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); - acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); - acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); - acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); - acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + - (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + - (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + - (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - - dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); - - qs += step/2; - sc += step; - dh += step/2; - } - - y4 += 16 * QK_K; - } -#endif for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -3551,7 +3566,6 @@ kernel void kernel_mul_mv_q2_K_f32( kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -#if QK_K == 256 void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, @@ -3710,84 +3724,6 @@ void kernel_mul_mv_q3_K_f32_impl( } } } -#else -void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; - - const int row = 2 * r0 + sgitg; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - const int ix = tiisg/4; - const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - float2 sum = {0.f, 0.f}; - - for (int i = ix; i < nb; i += 8) { - - const float d_all = (float)(x[i].d); - - device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); - device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); - device const uint16_t * s = (device const uint16_t *)(x[i].scales); - device const float * y = yy + i * QK_K + il; - - const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); - const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; - const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; - const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; - - for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> iq; - sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) - + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) - + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) - + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); - sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) - + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) - + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) - + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); - } - - } - const float sumf = sum[0] + sum[1] * 1.f/256.f; - - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } - -} -#endif [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( @@ -3817,7 +3753,6 @@ kernel void kernel_mul_mv_q3_K_f32( kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); } -#if QK_K == 256 void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, @@ -3931,103 +3866,6 @@ void kernel_mul_mv_q4_K_f32_impl( } } } -#else -void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne10, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - const int ix = tiisg/4; // 0...7 - const int it = tiisg%4; // 0...3 - - const int nb = ne00/QK_K; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - const int im = tgpig.z; - const int first_row = r0 * N_DST; - const int ib_row = first_row * nb; - - const uint i12 = im%ne12; - const uint i13 = im/ne12; - - const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[8]; - float yh[8]; - float sumf[N_DST]={0.f}, all_sum; - - const int step = sizeof(block_q4_K) * nb / 2; - - device const float * y4 = y + ix * QK_K + 8 * it; - - uint16_t sc16[4]; - - for (int ib = ix; ib < nb; ib += 8) { - - float2 sumy = {0.f, 0.f}; - for (int i = 0; i < 8; ++i) { - yl[i] = y4[i+ 0]; sumy[0] += yl[i]; - yh[i] = y4[i+32]; sumy[1] += yh[i]; - } - - device const uint16_t * sc = (device const uint16_t *)x[ib].scales; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; - device const half * dh = x[ib].d; - - for (int row = 0; row < N_DST; row++) { - - sc16[0] = sc[0] & 0x000f; - sc16[1] = sc[0] & 0x0f00; - sc16[2] = sc[0] & 0x00f0; - sc16[3] = sc[0] & 0xf000; - - float2 acc1 = {0.f, 0.f}; - float2 acc2 = {0.f, 0.f}; - for (int i = 0; i < 8; i += 2) { - acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); - acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); - acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); - acc2[1] += yh[i+1] * (qs[i/2] & 0xF000); - } - - float dall = dh[0]; - float dmin = dh[1]; - sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + - (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - - dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); - - qs += step; - sc += step; - dh += step; - } - - y4 += 8 * QK_K; - } - - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; - } - } -} -#endif [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( @@ -4095,8 +3933,6 @@ void kernel_mul_mv_q5_K_f32_impl( const int step = sizeof(block_q5_K) * nb; -#if QK_K == 256 -# float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; @@ -4179,54 +4015,6 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } -#else - float yl[8], yh[8]; - - const int il = 4 * (tiisg/8); // 0, 4, 8, 12 - const int ix = tiisg%8; - const int iq = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 - - device const float * y = yy + ix*QK_K + il; - - for (int i = ix; i < nb; i += 8) { - - for (int l = 0; l < 4; ++l) { - yl[l+0] = y[l+ 0]; - yl[l+4] = y[l+16]; - yh[l+0] = y[l+32]; - yh[l+4] = y[l+48]; - } - - device const half * dh = &x[i].d; - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].qh + in; - device const int8_t * s = x[i].scales; - - for (int row = 0; row < 2; ++row) { - - const float d = dh[0]; - - float2 acc = {0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> iq; - acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) - + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); - acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) - + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256)); - } - sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); - - q += step; - h += step; - s += step; - dh += step/2; - - } - - y += 8 * QK_K; - } -#endif for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); @@ -4305,7 +4093,6 @@ void kernel_mul_mv_q6_K_f32_impl( float sumf = 0; -#if QK_K == 256 const int tid = tiisg/2; const int ix = tiisg%2; const int ip = tid/8; // 0 or 1 @@ -4341,30 +4128,6 @@ void kernel_mul_mv_q6_K_f32_impl( } -#else - const int ix = tiisg/4; - const int il = 4*(tiisg%4); - - for (int i = ix; i < nb; i += 8) { - device const float * y = yy + i * QK_K + il; - device const uint8_t * ql = x[i].ql + il; - device const uint8_t * qh = x[i].qh + il; - device const int8_t * s = x[i].scales; - - const float d = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 4; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32); - sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]); - } - -#endif - const float tot = simd_sum(sumf); if (tiisg == 0) { dst[r1*ne0 + im*ne0*ne1 + row] = tot; @@ -5198,9 +4961,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const float * y4 = y + 32 * ix; -#if QK_K != 64 iq1m_scale_t scale; -#endif for (int ib32 = ix; ib32 < nb32; ib32 += 32) { @@ -5221,10 +4982,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const uint16_t * sc = (device const uint16_t *)xr->scales; for (int row = 0; row < N_DST; row++) { - -#if QK_K != 64 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); -#endif constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); @@ -5240,14 +4998,9 @@ void kernel_mul_mv_iq1_m_f32_impl( } const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); -#if QK_K == 64 - const float d = (float) *((device const half *)(sc - 1)); - sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) + - (sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1)); -#else + sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); -#endif sc += nb*sizeof(block_iq1_m)/2; qs += nb*sizeof(block_iq1_m); @@ -5359,7 +5112,6 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } -#if QK_K != 64 void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, @@ -5454,7 +5206,6 @@ void kernel_mul_mv_iq4_xs_f32_impl( } } } -#endif [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( @@ -5567,11 +5318,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { -#if QK_K == 64 - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -#else kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); -#endif } //============================= templates and their specializations ============================= @@ -5697,10 +5444,9 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg float dl, ml; uint8_t sc = xb->scales[il]; -#if QK_K == 256 q = q + 32*(il/8) + 16*(il&1); il = (il/2)%4; -#endif + half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); @@ -5716,7 +5462,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg device const uint8_t * h = (device const uint8_t *)xb->hmask; device const int8_t * scales = (device const int8_t *)xb->scales; -#if QK_K == 256 q = q + 32 * (il/8) + 16 * (il&1); h = h + 16 * (il&1); uint8_t m = 1 << (il/2); @@ -5737,17 +5482,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); } -#else - float kcoef = il&1 ? 1.f/16.f : 1.f; - uint16_t kmask = il&1 ? 0xF0 : 0x0F; - float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8); - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint8_t m = 1<<(il*2); - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef)); - } -#endif } static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { @@ -5759,7 +5493,6 @@ template void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { device const uchar * q = xb->qs; -#if QK_K == 256 short is = (il/4) * 2; q = q + (il/4) * 32 + 16 * (il&1); il = il & 3; @@ -5768,16 +5501,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg const float min = xb->dmin; const float dl = d * sc[0]; const float ml = min * sc[1]; -#else - (void) get_scale_min_k4_just2; - q = q + 16 * (il&1); - device const uint8_t * s = xb->scales; - device const half2 * dh = (device const half2 *)xb->d; - const float2 d = (float2)dh[0]; - const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); -#endif const ushort mask = il<2 ? 0x0F : 0xF0; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * (q[i] & mask) - ml; @@ -5789,7 +5513,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg device const uint8_t * q = xb->qs; device const uint8_t * qh = xb->qh; -#if QK_K == 256 short is = (il/4) * 2; q = q + 32 * (il/4) + 16 * (il&1); qh = qh + 16 * (il&1); @@ -5806,17 +5529,6 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; } -#else - q = q + 16 * (il&1); - device const int8_t * s = xb->scales; - const float dl = xb->d * s[il]; - uint8_t m = 1<<(il*2); - const float coef = il<2 ? 1.f : 1.f/16.f; - const ushort mask = il<2 ? 0x0F : 0xF0; - for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef)); - } -#endif } template @@ -5826,15 +5538,11 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg device const uint8_t * qh = (device const uint8_t *)xb->qh; device const int8_t * scales = (device const int8_t *)xb->scales; -#if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); float sc = scales[(il%2) + 2 * ((il/2))]; il = (il/2) & 3; -#else - ql = ql + 16 * (il&1); - float sc = scales[il]; -#endif + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; const float coef = il>1 ? 1.f/16.f : 1.f; @@ -5991,20 +5699,15 @@ void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & const int ib32 = il/2; il = il%2; device const uint16_t * sc = (device const uint16_t *)xb->scales; -#if QK_K == 64 - const float d = xb->d; -#else + iq1m_scale_t scale; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); const float d = scale.f16; -#endif + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; device const uint8_t * qh = xb->qh + 2*ib32 + il; -#if QK_K == 64 - const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1); -#else + const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); -#endif const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); @@ -6034,9 +5737,6 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 template void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { -#if QK_K == 64 - dequantize_iq4_nl(xb, il, reg); -#else // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 const int ib32 = il/2; il = il%2; @@ -6053,7 +5753,6 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; } -#endif } template @@ -6558,11 +6257,7 @@ kernel void kernel_mul_mm_id( sgitg); } -#if QK_K == 256 #define QK_NL 16 -#else -#define QK_NL 4 -#endif // // get rows @@ -6602,11 +6297,7 @@ template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_r template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows; -#if QK_K == 64 -template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; -#else template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows; -#endif // // matrix-matrix multiplication @@ -6634,11 +6325,7 @@ template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -#if QK_K == 64 -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -#else template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -#endif // // indirect matrix-matrix multiplication @@ -6666,11 +6353,7 @@ template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#if QK_K == 64 -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#else template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -#endif // // matrix-vector multiplication @@ -6879,7 +6562,5 @@ template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -#if QK_K != 64 template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; -#endif diff --git a/llama/ggml-metal.o b/llama/ggml-metal.o index 887a0d0b..ac44271d 100644 Binary files a/llama/ggml-metal.o and b/llama/ggml-metal.o differ diff --git a/llama/ggml-quants.c b/llama/ggml-quants.c index 8a8ef139..0eda1ffc 100644 --- a/llama/ggml-quants.c +++ b/llama/ggml-quants.c @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -288,6 +288,403 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 #endif +#if defined(__loongarch_asx) + +#ifdef __clang__ +#define VREGS_PREFIX "$vr" +#define XREGS_PREFIX "$xr" +#else // GCC +#define VREGS_PREFIX "$f" +#define XREGS_PREFIX "$f" +#endif +#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" +// Convert __m128i to __m256i +static inline __m256i ____m256i(__m128i in) { + __m256i out = __lasx_xvldi(0); + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX"\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "+f" (out) : [in] "f" (in) + ); + return out; +} +// Convert two __m128i to __m256i +static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { + __m256i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".ifnc %[out], %[hi] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" + " xvori.b $xr\\i, $xr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out), [hi] "+f" (inhi) + : [lo] "f" (inlo) + ); + return out; +} +// Convert __m256i low part to __m128i +static inline __m128i lasx_extracti128_lo(__m256i in) { + __m128i out; + __asm__ volatile ( + ".ifnc %[out], %[in] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " vori.b $vr\\i, $vr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} +// Convert __m256i high part to __m128i +static inline __m128i lasx_extracti128_hi(__m256i in) { + __m128i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { + v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; + return (__m256i)__ret; +} + +static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { + v4i32 __ret = {d, c, b, a}; + return (__m128i)__ret; +} + +static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { + v4i64 __ret = {d, c, b, a}; + return (__m256i)__ret; +} + +static __m256i lasx_insertf128( __m128i x, __m128i y) { + return lasx_set_q(x, y); +} + +static __m128i lsx_shuffle_b(__m128i a, __m128i b) { + __m128i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lsx_vreplgr2vr_b(f); + zero = __lsx_vldi(0); + tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones + return __lsx_vshuf_b(a, zero, tmp2); +} + +static __m256i lasx_shuffle_b(__m256i a, __m256i b) { + __m256i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lasx_xvreplgr2vr_b(f); + zero = __lasx_xvldi(0); + tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones + return __lasx_xvshuf_b(a, zero, tmp2); +} + +static __m256i lasx_extu8_16(__m128i a) { + __m128i zero = __lsx_vldi(0); + __m128i vlo = __lsx_vilvl_b(zero, a); + __m128i vhi = __lsx_vilvh_b(zero, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext8_16(__m128i a) { + __m128i sign = __lsx_vslti_b(a, 0); + __m128i vlo = __lsx_vilvl_b(sign, a); + __m128i vhi = __lsx_vilvh_b(sign, a); + return lasx_set_q(vhi, vlo); +} + +static __m256i lasx_ext16_32(__m128i a) { + __m256i tmp1; + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6); + tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7); + return tmp1; +} + +static __m128i lasx_extracti128( __m256i a, int pos) { + __m128i ret; + if( pos == 0) + { + ret = lasx_extracti128_lo(a); + } else { + ret = lasx_extracti128_hi(a); + } + return ret; +} + +static __m128 lasx_extractf128( __m256 a, int pos) { + __m128 ret; + if( pos == 0) + { + ret = (__m128)lasx_extracti128_lo((__m256i)a); + } else { + ret = (__m128)lasx_extracti128_hi((__m256i)a); + } + return ret; +} + +static __m128i lsx_hadd_h(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_h(b, a); + __m128i tmp2 = __lsx_vpickod_h(b, a); + return __lsx_vadd_h(tmp1, tmp2); +} + +static __m128i lsx_hadd_w(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_w(b, a); + __m128i tmp2 = __lsx_vpickod_w(b, a); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128 lsx_hadd_s(__m128 a, __m128 b) { + __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); + __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); + + return __lsx_vfadd_s(tmp1, tmp2); +} + +static __m256i lasx_maddubs_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvsadd_h(tmp1, tmp2); +} + +static __m256i lasx_madd_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_w_h(a, b); + tmp2 = __lasx_xvmulwod_w_h(a, b); + return __lasx_xvadd_w(tmp1, tmp2); +} + +static __m256i lasx_packs_w(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_w(a, 15); + tmp1 = __lasx_xvsat_w(b, 15); + return __lasx_xvpickev_h(tmp1, tmp); +} + +static __m256i lasx_packs_h(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_h(a, 7); + tmp1 = __lasx_xvsat_h(b, 7); + return __lasx_xvpickev_b(tmp1, tmp); +} + +static __m128i lsx_packs_w(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(a, 15); + tmp1 = __lsx_vsat_w(b, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +static __m128i lsx_packs_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_packus_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_hu(a, 7); + tmp1 = __lsx_vsat_hu(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + + +static __m128i lsx_maddubs_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_h_b(a, b); + tmp2 = __lsx_vmulwod_h_b(a, b); + return __lsx_vsadd_h(tmp1, tmp2); +} + +static __m128i lsx_madd_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_w_h(a, b); + tmp2 = __lsx_vmulwod_w_h(a, b); + return __lsx_vadd_w(tmp1, tmp2); +} + +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = __lsx_vsigncov_b(x, x); + // Sign the values of the y vectors + const __m128i sy = __lsx_vsigncov_b(x, y); + // Perform multiplication and create 16-bit values + const __m128i dot = lsx_maddubs_h(ax, sy); + const __m128i ones = __lsx_vreplgr2vr_h(1); + return lsx_madd_h(ones, dot); +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = lasx_extractf128(x, 1); + ft_union tmp; + res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); + tmp.i = __lsx_vpickve2gr_w(res, 0); + return tmp.f; +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + + __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); + __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); + + __m128i tmp1_128 = lasx_extracti128_lo(tmp1); + __m128i tmp2_128 = lasx_extracti128_lo(tmp2); + + __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); + + __m128i ev = __lsx_vpickev_w(sum128, sum128); + __m128i od = __lsx_vpickod_w(sum128, sum128); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + __m128i ev = __lsx_vpickev_w(a, a); + __m128i od = __lsx_vpickod_w(a, a); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = lasx_set_d( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + + __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); + const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); + bytes = __lasx_xvor_v(bytes, bit_mask); + return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); + __m128i hi = __lsx_vsrli_h(lo, 4); + return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + __m256i v = __lasx_xvpackod_h(x, x); + __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); + return __lasx_xvffint_s_w(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + // Perform multiplication and create 16-bit values + const __m256i dot = lasx_maddubs_h(ax, sy); + return sum_i16_pairs_float(dot); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + + // Get absolute values of x vectors + const __m256i ax = __lasx_xvsigncov_b(x, x); + // Sign the values of the y vectors + const __m256i sy = __lasx_xvsigncov_b(x, y); + + return mul_sum_us8_pairs_float(ax, sy); +} + +static inline __m128i packNibbles( __m256i bytes ) { + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); + __m256i high = __lasx_xvandn_v(lowByte, bytes); + __m256i low = __lasx_xvand_v(lowByte, bytes); + high = __lasx_xvsrli_h(high, 4); + bytes = __lasx_xvor_v(low, high); + // Compress uint16_t lanes into bytes + __m128i *r0 = (__m128i *)&bytes; + __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); + __m128i *r1 = (__m128i *)&tmp_h128; + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, *r0); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, *r1); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); +} +#endif //__loongarch_asx + // reference implementation for deterministic creation of model files void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -675,6 +1072,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) // store result __riscv_vse8_v_i8m1(y[i].qs , vs, vl); } + #elif defined(__POWER9_VECTOR__) for (int i = 0; i < nb; i++) { vector float srcv [8]; @@ -706,6 +1104,69 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) } vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + +#elif defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + ft_union fi; + __m256 v0 = (__m256)__lasx_xvld( x , 0); + __m256 v1 = (__m256)__lasx_xvld( x , 32); + __m256 v2 = (__m256)__lasx_xvld( x , 64); + __m256 v3 = (__m256)__lasx_xvld( x , 96); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); + fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); + const float max_scalar = fi.f; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128( i0, 0 ); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + } #else GGML_UNUSED(nb); @@ -854,12 +1315,12 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); + const float max_scalar = _mm_cvtss_f32( max4 ); // Quantize these floats - const float d = maxScalar / 127.f; + const float d = max_scalar / 127.f; y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); // Apply the multiplier @@ -962,6 +1423,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); y[i].s = GGML_FP32_TO_FP16(sum*d); } + #elif defined(__POWER9_VECTOR__) for (int i = 0; i < nb; i++) { vector float srcv [8]; @@ -1001,6 +1463,73 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) accv = vec_add(accv, vec_sld(accv, accv, 4)); accv = vec_add(accv, vec_sld(accv, accv, 8)); y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + +#elif defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + ft_union ft; + __m256 v0 = (__m256)__lasx_xvld( x , 0 ); + __m256 v1 = (__m256)__lasx_xvld( x , 32 ); + __m256 v2 = (__m256)__lasx_xvld( x , 64 ); + __m256 v3 = (__m256)__lasx_xvld( x , 96 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); + ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 ); + const float max_scalar = ft.f; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = __lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128(i0, 0); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0 ); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); + const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); } #else GGML_UNUSED(nb); @@ -1175,7 +1704,7 @@ static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * sumlx += w*x[i]*l; suml2 += w*l*l; } - float scale = sumlx/suml2; + float scale = suml2 ? sumlx/suml2 : 0.0f; if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; float best = scale * sumlx; for (int is = -9; is <= 9; ++is) { @@ -1385,7 +1914,6 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f return scale; } -#if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { if (j < 4) { *d = q[j] & 63; *m = q[j + 4] & 63; @@ -1394,7 +1922,6 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); } } -#endif //========================- 2-bit (de)-quantization @@ -1458,20 +1985,13 @@ void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict } } -#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } -#else - for (int l = 0; l < 16; ++l) { - y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); - } -#endif x += QK_K; - } } @@ -1486,7 +2006,6 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 const uint8_t * q = x[i].qs; -#if QK_K == 256 int is = 0; float dl, ml; for (int n = 0; n < QK_K; n += 128) { @@ -1505,19 +2024,6 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6 } q += 32; } -#else - float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); - float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); - float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); - float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); - for (int l = 0; l < 16; ++l) { - y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; - y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; - y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; - y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; - } - y += QK_K; -#endif } } @@ -1708,36 +2214,9 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri } float dm, mm; -#if QK_K == 64 - float max_scale = 0, max_min = 0; - for (int j = 0; j < QK_K/16; ++j) { - max_scale = MAX(max_scale, scales[j]); - max_min = MAX(max_min, mins[j]); - } - dm = max_scale/15; - mm = max_min/15; - if (max_scale) { - float id = 1/dm; - for (int j = 0; j < QK_K/16; ++j) { - int l = nearest_int(id*scales[j]); - Ls[j] = MAX(0, MIN(15, l)); - } - } else { - memset(Ls, 0, QK_K/16); - } - if (max_min) { - float id = 1/mm; - for (int j = 0; j < QK_K/16; ++j) { - int l = nearest_int(id*mins[j]); - Lm[j] = MAX(0, MIN(15, l)); - } - } else { - memset(Lm, 0, QK_K/16); - } -#else dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw); mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw); -#endif + y[i].d = GGML_FP32_TO_FP16(dm); y[i].dmin = GGML_FP32_TO_FP16(mm); dm = GGML_FP16_TO_FP32(y[i].d); @@ -1760,20 +2239,13 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri } } -#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } -#else - for (int l = 0; l < 16; ++l) { - y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); - } -#endif x += QK_K; - } } @@ -1814,7 +2286,6 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict } } -#if QK_K == 256 memset(y[i].scales, 0, 12); if (max_scale) { float iscale = -32.f/max_scale; @@ -1848,36 +2319,6 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict L[16*j + ii] = l + 4; } } -#else - if (max_scale) { - float iscale = -8.f/max_scale; - for (int j = 0; j < QK_K/16; j+=2) { - int l1 = nearest_int(iscale*scales[j]); - l1 = 8 + MAX(-8, MIN(7, l1)); - int l2 = nearest_int(iscale*scales[j+1]); - l2 = 8 + MAX(-8, MIN(7, l2)); - y[i].scales[j/2] = l1 | (l2 << 4); - } - y[i].d = GGML_FP32_TO_FP16(1/iscale); - } else { - for (int j = 0; j < QK_K/16; j+=2) { - y[i].scales[j/2] = 0; - } - y[i].d = GGML_FP32_TO_FP16(0.f); - } - for (int j = 0; j < QK_K/16; ++j) { - int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4; - float d = GGML_FP16_TO_FP32(y[i].d) * (s - 8); - if (!d) { - continue; - } - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-4, MIN(3, l)); - L[16*j + ii] = l + 4; - } - } -#endif memset(y[i].hmask, 0, QK_K/8); // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. @@ -1892,23 +2333,16 @@ void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict m = 0; hm <<= 1; } } -#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); } } -#else - for (int l = 0; l < 16; ++l) { - y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); - } -#endif x += QK_K; } } -#if QK_K == 256 void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1958,49 +2392,12 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6 } } -#else -void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int64_t k) { - assert(k % QK_K == 0); - assert(QK_K == 64); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - - const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); - const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); - const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); - const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); - - for (int l=0; l<8; ++l) { - uint8_t h = hm[l]; - y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); - y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); - y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); - y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); - y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); - y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); - y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); - y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); - } - y += QK_K; - } -} -#endif void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) { quantize_row_q3_K_reference(x, vy, k); } static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) { -#if QK_K != 256 - (void)quant_weights; - quantize_row_q3_K_reference(x, y, n_per_row); -#else assert(n_per_row % QK_K == 0); const int nb = n_per_row / QK_K; @@ -2082,7 +2479,6 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri x += QK_K; } -#endif } size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -2114,7 +2510,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict float scales[QK_K/32]; for (int i = 0; i < nb; i++) { - float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { @@ -2134,7 +2529,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } } -#if QK_K == 256 float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; float inv_min = max_min > 0 ? 63.f/max_min : 0.f; for (int j = 0; j < QK_K/32; ++j) { @@ -2166,39 +2560,7 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict L[32*j + ii] = l; } } -#else - const float s_factor = 15.f; - float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; - float inv_min = max_min > 0 ? s_factor/max_min : 0.f; - int d1 = nearest_int(inv_scale*scales[0]); - int m1 = nearest_int(inv_min*mins[0]); - int d2 = nearest_int(inv_scale*scales[1]); - int m2 = nearest_int(inv_min*mins[1]); - y[i].scales[0] = d1 | (m1 << 4); - y[i].scales[1] = d2 | (m2 << 4); - y[i].d[0] = GGML_FP32_TO_FP16(max_scale/s_factor); - y[i].d[1] = GGML_FP32_TO_FP16(max_min/s_factor); - float sumlx = 0; - int suml2 = 0; - for (int j = 0; j < QK_K/32; ++j) { - const uint8_t sd = y[i].scales[j] & 0xF; - const uint8_t sm = y[i].scales[j] >> 4; - const float d = GGML_FP16_TO_FP32(y[i].d[0]) * sd; - if (!d) continue; - const float m = GGML_FP16_TO_FP32(y[i].d[1]) * sm; - for (int ii = 0; ii < 32; ++ii) { - int l = nearest_int((x[32*j + ii] + m)/d); - l = MAX(0, MIN(15, l)); - L[32*j + ii] = l; - sumlx += (x[32*j + ii] + m)*l*sd; - suml2 += l*l*sd*sd; - } - } - if (suml2) { - y[i].d[0] = GGML_FP32_TO_FP16(sumlx/suml2); - } -#endif uint8_t * q = y[i].qs; for (int j = 0; j < QK_K; j += 64) { for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); @@ -2206,7 +2568,6 @@ void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict } x += QK_K; - } } @@ -2215,11 +2576,8 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 const int nb = k / QK_K; for (int i = 0; i < nb; i++) { - const uint8_t * q = x[i].qs; -#if QK_K == 256 - const float d = GGML_FP16_TO_FP32(x[i].d); const float min = GGML_FP16_TO_FP32(x[i].dmin); @@ -2234,18 +2592,6 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6 for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; q += 32; is += 2; } -#else - const float dall = GGML_FP16_TO_FP32(x[i].d[0]); - const float mall = GGML_FP16_TO_FP32(x[i].d[1]); - const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4); - const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4); - for (int l = 0; l < 32; ++l) { - y[l+ 0] = d1 * (q[l] & 0xF) - m1; - y[l+32] = d2 * (q[l] >> 4) - m2; - } - y += QK_K; -#endif - } } @@ -2256,10 +2602,6 @@ void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) } static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) { -#if QK_K != 256 - (void)quant_weights; - quantize_row_q4_K_reference(x, y, n_per_row); -#else assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -2330,7 +2672,6 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri x += QK_K; } -#endif } size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -2355,21 +2696,13 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict assert(k % QK_K == 0); const int64_t nb = k / QK_K; -#if QK_K == 256 uint8_t L[QK_K]; float mins[QK_K/32]; float scales[QK_K/32]; float weights[32]; uint8_t Laux[32]; -#else - int8_t L[QK_K]; - float scales[QK_K/16]; -#endif for (int i = 0; i < nb; i++) { - -#if QK_K == 256 - float max_scale = 0; // as we are deducting the min, scales are always positive float max_min = 0; for (int j = 0; j < QK_K/32; ++j) { @@ -2441,55 +2774,8 @@ void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict m1 <<= 2; m2 <<= 2; ql += 32; } -#else - float max_scale = 0, amax = 0; - for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1, NULL); - float abs_scale = fabsf(scales[j]); - if (abs_scale > amax) { - amax = abs_scale; - max_scale = scales[j]; - } - } - - float iscale = -128.f/max_scale; - for (int j = 0; j < QK_K/16; ++j) { - int l = nearest_int(iscale*scales[j]); - y[i].scales[j] = MAX(-128, MIN(127, l)); - } - y[i].d = GGML_FP32_TO_FP16(1/iscale); - - for (int j = 0; j < QK_K/16; ++j) { - const float d = GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; - if (!d) continue; - for (int ii = 0; ii < 16; ++ii) { - int l = nearest_int(x[16*j + ii]/d); - l = MAX(-16, MIN(15, l)); - L[16*j + ii] = l + 16; - } - } - - uint8_t * restrict qh = y[i].qh; - uint8_t * restrict ql = y[i].qs; - memset(qh, 0, QK_K/8); - - for (int j = 0; j < 32; ++j) { - int jm = j%8; - int is = j/8; - int l1 = L[j]; - if (l1 > 15) { - l1 -= 16; qh[jm] |= (1 << is); - } - int l2 = L[j + 32]; - if (l2 > 15) { - l2 -= 16; qh[jm] |= (1 << (4 + is)); - } - ql[j] = l1 | (l2 << 4); - } -#endif x += QK_K; - } } @@ -2498,12 +2784,9 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { - const uint8_t * ql = x[i].qs; const uint8_t * qh = x[i].qh; -#if QK_K == 256 - const float d = GGML_FP16_TO_FP32(x[i].d); const float min = GGML_FP16_TO_FP32(x[i].dmin); @@ -2520,21 +2803,6 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6 ql += 32; is += 2; u1 <<= 2; u2 <<= 2; } -#else - float d = GGML_FP16_TO_FP32(x[i].d); - const int8_t * restrict s = x[i].scales; - for (int l = 0; l < 8; ++l) { - y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); - y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); - y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); - y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); - y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); - y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); - y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); - y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); - } - y += QK_K; -#endif } } @@ -2545,10 +2813,6 @@ void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) } static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) { -#if QK_K != 256 - (void)quant_weights; - quantize_row_q5_K_reference(x, y, n_per_row); -#else assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -2639,7 +2903,6 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri x += QK_K; } -#endif } size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -2712,7 +2975,6 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict uint8_t * restrict ql = y[i].ql; uint8_t * restrict qh = y[i].qh; -#if QK_K == 256 for (int j = 0; j < QK_K; j += 128) { for (int l = 0; l < 32; ++l) { const uint8_t q1 = L[j + l + 0] & 0xF; @@ -2726,19 +2988,8 @@ void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict ql += 64; qh += 32; } -#else - for (int l = 0; l < 32; ++l) { - const uint8_t q1 = L[l + 0] & 0xF; - const uint8_t q2 = L[l + 32] & 0xF; - ql[l] = q1 | (q2 << 4); - } - for (int l = 0; l < 16; ++l) { - qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); - } -#endif x += QK_K; - } } @@ -2747,14 +2998,12 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); const uint8_t * restrict ql = x[i].ql; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict sc = x[i].scales; -#if QK_K == 256 for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -2772,20 +3021,6 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6 qh += 32; sc += 8; } -#else - for (int l = 0; l < 16; ++l) { - const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l+ 0] = d * sc[0] * q1; - y[l+16] = d * sc[1] * q2; - y[l+32] = d * sc[2] * q3; - y[l+48] = d * sc[3] * q4; - } - y += 64; -#endif - } } @@ -2796,10 +3031,6 @@ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) } static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) { -#if QK_K != 256 - (void)quant_weights; - quantize_row_q6_K_reference(x, y, n_per_row); -#else assert(n_per_row % QK_K == 0); const int64_t nb = n_per_row / QK_K; @@ -2881,7 +3112,6 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri x += QK_K; } -#endif } size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { @@ -3298,30 +3528,21 @@ void dequantize_row_iq1_m(const block_iq1_m * restrict x, float * restrict y, in float delta[4]; uint16_t idx[4]; -#if QK_K != 64 iq1m_scale_t scale; -#endif for (int i = 0; i < nb; i++) { const uint16_t * sc = (const uint16_t *)x[i].scales; -#if QK_K == 64 - const float d = GGML_FP16_TO_FP32(x[i].d); -#else scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); const float d = GGML_FP16_TO_FP32(scale.f16); -#endif + const uint8_t * qs = x[i].qs; const uint8_t * qh = x[i].qh; for (int ib = 0; ib < QK_K/32; ++ib) { -#if QK_K == 64 - const float dl1 = d * (2*((sc[ib/2] >> (8*(ib%2)+0)) & 0xf) + 1); - const float dl2 = d * (2*((sc[ib/2] >> (8*(ib%2)+4)) & 0xf) + 1); -#else const float dl1 = d * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1); const float dl2 = d * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1); -#endif + idx[0] = qs[0] | ((qh[0] << 8) & 0x700); idx[1] = qs[1] | ((qh[0] << 4) & 0x700); idx[2] = qs[2] | ((qh[1] << 8) & 0x700); @@ -3372,9 +3593,6 @@ void dequantize_row_iq4_nl(const block_iq4_nl * restrict x, float * restrict y, void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, int64_t k) { assert(k % QK_K == 0); -#if QK_K == 64 - dequantize_row_iq4_nl((const block_iq4_nl *)x, y, k); -#else const int64_t nb = k / QK_K; for (int i = 0; i < nb; i++) { @@ -3394,7 +3612,6 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y, qs += 16; } } -#endif } //===================================== Q8_K ============================================== @@ -3496,6 +3713,43 @@ static inline __m128i get_scale_shuffle(int i) { }; return _mm_loadu_si128((const __m128i*)k_shuffle + i); } +#elif defined(__loongarch_asx) +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return __lsx_vld((const __m128i*)k_shuffle + i, 0); +} #endif void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -3585,7 +3839,44 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r return; } #endif -#if defined(__ARM_NEON) +#if defined(__ARM_FEATURE_SVE) + const svbool_t ptrueh = svptrue_pat_b8(SV_VL16); + const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh); + + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); +#elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -3845,6 +4136,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r } *s = sumf; + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -3885,6 +4177,149 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + __m256i qx = bytes_from_nibbles_32(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = __lasx_xvreplgr2vr_b( 8 ); + qx = __lasx_xvsub_b( qx, off ); + + __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + *s = hsum_float_8(acc); +#elif defined(__loongarch_sx) + // set constants + const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); + const __m128i off = __lsx_vreplgr2vr_b(8); + + // Initialize accumulator with zeros + __m128 acc_0 = __lsx_vldi(0); + __m128 acc_1 = __lsx_vldi(0); + __m128 acc_2 = __lsx_vldi(0); + __m128 acc_3 = __lsx_vldi(0); + + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); + + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[0].qs, 0); + + __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); + __m128i by_0 = __lsx_vld((const __m128i *)y[0].qs, 0); + bx_0 = __lsx_vsub_b(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); + __m128i by_1 = __lsx_vld((const __m128i *)(y[0].qs + 16), 0); + bx_1 = __lsx_vsub_b(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[1].qs, 0); + + __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); + __m128i by_2 = __lsx_vld((const __m128i *)y[1].qs, 0); + bx_2 = __lsx_vsub_b(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); + __m128i by_3 = __lsx_vld((const __m128i *)(y[1].qs + 16), 0); + bx_3 = __lsx_vsub_b(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = __lsx_vffint_s_w(i32_0); + __m128 p1 = __lsx_vffint_s_w(i32_1); + __m128 p2 = __lsx_vffint_s_w(i32_2); + __m128 p3 = __lsx_vffint_s_w(i32_3); + + // Apply the scale + acc_0 = __lsx_vfmul_s( d_0_1, p0 ); + acc_1 = __lsx_vfmul_s( d_0_1, p1 ); + acc_2 = __lsx_vfmul_s( d_2_3, p2 ); + acc_3 = __lsx_vfmul_s( d_2_3, p3 ); + } + + assert(nb % 2 == 0); // TODO: handle odd nb + + // Main loop + for (int i = 2; i < nb; i+=2) { + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[i].qs, 0); + + __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); + __m128i by_0 = __lsx_vld((const __m128i *)y[i].qs, 0); + bx_0 = __lsx_vsub_b(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); + __m128i by_1 = __lsx_vld((const __m128i *)(y[i].qs + 16), 0); + bx_1 = __lsx_vsub_b(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + //_mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + //_mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); + + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[i + 1].qs, 0); + + __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); + __m128i by_2 = __lsx_vld((const __m128i *)y[i + 1].qs, 0); + bx_2 = __lsx_vsub_b(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); + __m128i by_3 = __lsx_vld((const __m128i *)(y[i + 1].qs + 16), 0); + bx_3 = __lsx_vsub_b(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = __lsx_vffint_s_w(i32_0); + __m128 p1 = __lsx_vffint_s_w(i32_1); + __m128 p2 = __lsx_vffint_s_w(i32_2); + __m128 p3 = __lsx_vffint_s_w(i32_3); + + // Apply the scale + __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); + __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); + __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); + __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); + + // Acummulate + acc_0 = __lsx_vfadd_s(p0_d, acc_0); + acc_1 = __lsx_vfadd_s(p1_d, acc_1); + acc_2 = __lsx_vfadd_s(p2_d, acc_2); + acc_3 = __lsx_vfadd_s(p3_d, acc_3); + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); + #else // scalar float sumf = 0.0; @@ -4104,6 +4539,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r } *s = sumf; + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -4144,6 +4580,38 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = GGML_FP16_TO_FP32(y[i].d); + + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + + const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); + const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); + + // Compute combined scales + const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[i].qs); + const __m256i qy = __lasx_xvld( (const __m256i *)y[i].qs, 0); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y + acc = __lasx_xvfmadd_s( d0d1, xy, acc ); + } + + *s = hsum_float_8(acc) + summs; + #else // scalar float sumf = 0.0; @@ -4429,6 +4897,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r } *s = sumf; + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)4); @@ -4472,6 +4941,31 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); //FIXME + + __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); + qx = __lasx_xvor_v(qx, bxhi); + + __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s(d, q, acc); + } + + *s = hsum_float_8(acc); + #else // scalar float sumf = 0.0; @@ -4776,6 +5270,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r } *s = sumf; + #elif defined(__POWER9_VECTOR__) const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); @@ -4823,6 +5318,34 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d)); + + summs += GGML_FP16_TO_FP32(x[i].m) * GGML_FP16_TO_FP32(y[i].s); + + __m256i qx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); + qx = __lasx_xvor_v(qx, bxhi); + + const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[i].d)); + const __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; + #else // scalar float sumf = 0.0; @@ -4924,7 +5447,32 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r return; } #endif -#if defined(__ARM_NEON) +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + *s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); +#elif defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -4999,6 +5547,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r } *s = sumf; + #elif defined(__POWER9_VECTOR__) vector float vsumf0 = vec_splats(0.0f); @@ -5038,6 +5587,26 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); + __m256i qx = __lasx_xvld((const __m256i *)x[i].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[i].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + *s = hsum_float_8(acc); + #else // scalar float sumf = 0.0; @@ -5056,7 +5625,6 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r #endif } -#if QK_K == 256 void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -5442,8 +6010,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed int vsumi6 = vec_splats((int32_t)0); vector signed int vsumi7 = vec_splats((int32_t)0); - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; for (int j = 0; j < QK_K/128; ++j) { __builtin_prefetch(q2, 0, 1); @@ -5534,6 +6100,72 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); +#elif defined __loongarch_asx + + const __m256i m3 = __lasx_xvreplgr2vr_b(3); + const __m128i m4 = __lsx_vreplgr2vr_b(0xF); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4); + const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4); + const __m256i mins = lasx_ext8_16(mins8); + const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); + + const __m256i all_scales = lasx_ext8_16(scales8); + const __m128i l_scales = lasx_extracti128(all_scales, 0); + const __m128i h_scales = lasx_extracti128(all_scales, 1); + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i q2_0 = __lasx_xvand_v(q2bits, m3); + const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3); + const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3); + const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3); + + __m256i p0 = lasx_maddubs_h(q2_0, q8_0); + __m256i p1 = lasx_maddubs_h(q2_1, q8_1); + __m256i p2 = lasx_maddubs_h(q2_2, q8_2); + __m256i p3 = lasx_maddubs_h(q2_3, q8_3); + + p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = __lasx_xvadd_w(p0, p1); + p2 = __lasx_xvadd_w(p2, p3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); + } + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc); + #else float sumf = 0; @@ -5577,352 +6209,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r #endif } -#else - -void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); - - const int32x4_t vzero = vdupq_n_s32(0); - - ggml_int8x16x4_t q2bytes; - - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; - - sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - - int isum1 = 0, isum2 = 0; - - const uint8x16_t q2bits = vld1q_u8(q2); - - const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); - q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); - q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); - - isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; - isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; - isum1 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; - isum2 += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; - - sum += d * (isum1 + isum2); - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; - - float summs = 0; - - // TODO: optimize this - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; - - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; - - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); - const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - - const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); - const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); - const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); - const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); - - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); - acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t ud, um; - const uint8_t * restrict db = (const uint8_t *)&ud; - const uint8_t * restrict mb = (const uint8_t *)&um; - - float summs = 0; - - // TODO: optimize this - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - ud = (sc[0] >> 0) & 0x0f0f0f0f; - um = (sc[0] >> 4) & 0x0f0f0f0f; - - int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; - summs += dmin * smin; - - const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); - - const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); - const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); - const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); - const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); - - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __riscv_v_intrinsic - - uint32_t aux32[2]; - const uint8_t * scales = (const uint8_t *)aux32; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * restrict q2 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - const uint32_t * restrict sc = (const uint32_t *)x[i].scales; - - aux32[0] = sc[0] & 0x0f0f0f0f; - aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; - - sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); - - int isum1 = 0; - int isum2 = 0; - - size_t vl = 16; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - // load Q2 - vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); - - vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); - vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); - vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); - vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); - - // load Q8, and take product with Q2 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); - vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); - vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); - - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; - isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; - isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; - - sumf += d * (isum1 + isum2); - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowScaleMask = vec_splats((signed char)0xF); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - -#pragma GCC unroll 2 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl_len(y[i].bsums, 8); - - vector signed char q2xmins = (vector signed char)vec_xl_len(x[i].scales, 4); - vector signed char vscales = vec_and(q2xmins, lowScaleMask); - - q2xmins = vec_sr(q2xmins, v4); - vector signed short q2xmins0 = vec_unpackh((vector signed char)q2xmins); - - vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); - vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char q2x00 = vec_and(qxs0, lowMask); - vector signed char q2x01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char q2x02 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char q2x03 = vec_and(vec_sr(qxs0, v6), lowMask); - - vector signed char q8y00 = vec_xl( 0, y[i].qs); - vector signed char q8y01 = vec_xl( 16, y[i].qs); - vector signed char q8y02 = vec_xl( 32, y[i].qs); - vector signed char q8y03 = vec_xl( 48, y[i].qs); - - vector signed short qv0 = vec_add(vec_mule(q2x00, q8y00), vec_mulo(q2x00, q8y00)); - vector signed short qv1 = vec_add(vec_mule(q2x01, q8y01), vec_mulo(q2x01, q8y01)); - vector signed short qv2 = vec_add(vec_mule(q2x02, q8y02), vec_mulo(q2x02, q8y02)); - vector signed short qv3 = vec_add(vec_mule(q2x03, q8y03), vec_mulo(q2x03, q8y03)); - - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - - vector signed int vsumi0 = vec_add(vec_mule(qv0, vs0), vec_mulo(qv0, vs0)); - vector signed int vsumi1 = vec_add(vec_mule(qv1, vs1), vec_mulo(qv1, vs1)); - vector signed int vsumi2 = vec_add(vec_mule(qv2, vs2), vec_mulo(qv2, vs2)); - vector signed int vsumi3 = vec_add(vec_mule(qv3, vs3), vec_mulo(qv3, vs3)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#else - - float sumf = 0; - - int isum[QK_K/16]; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - int summs = 0; - for (int j = 0; j < QK_K/16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); - } - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memset(isum, 0, (QK_K/16)*sizeof(int)); - for (int l = 0; l < 16; ++l) { - isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); - isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); - isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); - isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); - } - for (int l = 0; l < QK_K/16; ++l) { - isum[l] *= (sc[l] & 0xF); - } - sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; - } - *s = sumf; -#endif -} -#endif - -#if QK_K == 256 void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -6422,6 +6708,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r vector signed int vsumi6 = vec_splats((int32_t)0); vector signed int vsumi7 = vec_splats((int32_t)0); + const uint8_t * restrict q3 = x[i].qs; const int8_t * restrict q8 = y[i].qs; @@ -6533,6 +6820,112 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined __loongarch_asx + + const __m256i m3 = __lasx_xvreplgr2vr_b(3); + const __m256i mone = __lasx_xvreplgr2vr_b(1); + const __m128i m32 = __lsx_vreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = lsx_set_w( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = __lsx_vsub_b(scales128, m32); + const __m256i all_scales = lasx_ext8_16(scales128); + const __m128i l_scales = lasx_extracti128(all_scales, 0); + const __m128i h_scales = lasx_extracti128(all_scales, 1); + const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)}; + + // high bit + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); + + // integer accumulator + __m256i sumi = __lasx_xvldi(0); + + int bit = 0; + int is = 0; + __m256i xvbit; + + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; + + xvbit = __lasx_xvreplgr2vr_h(bit); + // prepare low and high bits + const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3); + const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3); + const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3); + const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + xvbit = __lasx_xvreplgr2vr_h(bit); + const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3); + const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0); + __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1); + __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2); + __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3); + + __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1); + __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2); + __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3); + + p16_0 = __lasx_xvsub_h(p16_0, q8s_0); + p16_1 = __lasx_xvsub_h(p16_1, q8s_1); + p16_2 = __lasx_xvsub_h(p16_2, q8s_2); + p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + + // multiply with scales + p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = __lasx_xvadd_w(p16_0, p16_1); + p16_2 = __lasx_xvadd_w(p16_2, p16_3); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); + } + // multiply with block scale and accumulate + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME + } + + *s = hsum_float_8(acc); + #else // scalar version // This function is written like this so the compiler can manage to vectorize most of it @@ -6600,445 +6993,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r } -#else - -void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q3_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - const int32x4_t vzero = vdupq_n_s32(0); - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const uint8x16_t mh = vdupq_n_u8(4); - - ggml_int8x16x4_t q3bytes; - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - ggml_uint8x16x4_t q3h; - - const uint8x8_t hbits = vld1_u8(x[i].hmask); - const uint8x16_t q3bits = vld1q_u8(x[i].qs); - const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs); - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; - - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); - q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); - q3h.val[1] = vandq_u8(mh, htmp); - q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); - q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); - - q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); - q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); - q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); - q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; - - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i m1 = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); - const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); - - memcpy(&aux64, x[i].hmask, 8); - - const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); - __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); - q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); - q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); - - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); - - // prepare low and high bits - const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); - const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - - // multiply with scales - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); - - p16_0 = _mm256_add_epi32(p16_0, p16_1); - - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m1 = _mm_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - uint64_t aux64; - - uint16_t aux16[2]; - const int8_t * aux8 = (const int8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); - const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); - const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); - const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); - - memcpy(&aux64, x[i].hmask, 8); - - __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); - __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); - __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); - __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); - q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); - q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); - q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); - q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); - - // load low 2 bits - const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); - - // prepare low and high bits - const __m128i q3l_0 = _mm_and_si128(q3bits, m3); - const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_1, p16_1); - p16_2 = _mm_madd_epi16(scale_2, p16_2); - p16_3 = _mm_madd_epi16(scale_3, p16_3); - - p16_0 = _mm_add_epi32(p16_0, p16_2); - p16_1 = _mm_add_epi32(p16_1, p16_3); - __m256i p16 = MM256_SET_M128I(p16_1, p16_0); - - // multiply with block scale and accumulate - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q3 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - for (int j = 0; j < 4; ++j) scales[j] -= 8; - - int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); - vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); - - size_t vl = 16; - - // extend and combine both qh_x1 and qh_x2 - vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); - - vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); - vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); - vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); - - // load Q3 - vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); - - vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); - vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); - vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); - vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); - - vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); - vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); - vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); - vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); - - // load Q8 and take product with Q3 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; - - sumf += d * isum; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char v1 = vec_splats((signed char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x8); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - -#pragma GCC unroll 2 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - uint16_t aux16[2]; - int8_t * scales = (int8_t *)aux16; - - const uint16_t a = *(const uint16_t *)x[i].scales; - aux16[0] = a & 0x0f0f; - aux16[1] = (a >> 4) & 0x0f0f; - - vector signed char vscales = (vector signed char)vec_xl_len(scales, 8); - vector signed char qxhs0 = (vector signed char)vec_xl_len(x[i].hmask, 8); - qxhs0 = vec_or(qxhs0, vec_sr(vec_sld(qxhs0, qxhs0, 8), (vector unsigned char)v1)); - - vscales = vec_sub(vscales, off); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char qxs10 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char qxs11 = vec_and(vec_sr(qxs0, v6), lowMask); - - //the 3rd bit - vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); - vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); - vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v4)), v2); - vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v6)), v2); - qxhs0 = vec_sr(qxhs0, v4); - - vector signed char q3x00 = vec_sub(qxs00, qxh00); - vector signed char q3x01 = vec_sub(qxs01, qxh01); - vector signed char q3x10 = vec_sub(qxs10, qxh02); - vector signed char q3x11 = vec_sub(qxs11, qxh03); - - vector signed char q8y00 = vec_xl( 0, y[i].qs); - vector signed char q8y01 = vec_xl( 16, y[i].qs); - vector signed char q8y10 = vec_xl( 32, y[i].qs); - vector signed char q8y11 = vec_xl( 48, y[i].qs); - - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - - vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); - vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); - vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); - vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); - - vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); - vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); - vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); - vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - int32_t scales[4]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict hm = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 8; ++l) { - a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); - a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); - a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); - a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); - a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); - a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); - a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); - a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); - } - - scales[0] = (x[i].scales[0] & 0xF) - 8; - scales[1] = (x[i].scales[0] >> 4) - 8; - scales[2] = (x[i].scales[1] & 0xF) - 8; - scales[3] = (x[i].scales[1] >> 4) - 8; - - memset(aux32, 0, 8*sizeof(int32_t)); - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; - -#endif - -} -#endif - -#if QK_K == 256 void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -7476,8 +7430,78 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); -#else +#elif defined __loongarch_asx + GGML_UNUSED(kmask1); + GGML_UNUSED(kmask2); + GGML_UNUSED(kmask3); + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + + __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); + + const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); + const __m256i scales = lasx_insertf128(sc128, sc128); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4l = __lasx_xvand_v(q4bits, m4); + const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4); + + const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16l = lasx_maddubs_h(q4l, q8l); + p16l = lasx_madd_h(scale_l, p16l); + + const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16h = lasx_maddubs_h(q4h, q8h); + p16h = lasx_madd_h(scale_h, p16h); + const __m256i sumj = __lasx_xvadd_w(p16l, p16h); + + sumi = __lasx_xvadd_w(sumi, sumj); + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); + __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); + acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); + + + ft_union fi; + fi.i = __lsx_vpickve2gr_w(acc_m, 0); + *s = hsum_float_8(acc) + fi.f ; +#else const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; @@ -7535,336 +7559,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = sumf; #endif } -#else -void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - const block_q4_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - - const int32x4_t mzero = vdupq_n_s32(0); - - float sumf = 0; - - ggml_int8x16x2_t q4bytes; - ggml_int8x16x4_t q8bytes; - - float sum_mins = 0.f; - - uint16_t aux16[2]; - const uint8_t * restrict scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t * restrict a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); - sum_mins += y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * summi; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); - - const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); - - q8bytes = ggml_vld1q_s8_x4(q8); - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; - - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); - const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; - - sumf += d * (sumi1 + sumi2); - } - - *s = sumf - sum_mins; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; - const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - - const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); - - const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); - - } - - *s = hsum_float_8(acc) - summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - uint16_t aux16[2]; - const uint8_t * scales = (const uint8_t *)aux16; - - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; - const float m = GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; - const __m256 vd = _mm256_set1_ps(d); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); - const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); - const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); - const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); - const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); - const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - - const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); - const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); - - const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); - const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); - - } - - *s = hsum_float_8(acc) - summs; - -#elif defined __riscv_v_intrinsic - - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; - - sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); - - size_t vl = 32; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); - - sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); - - sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - -#pragma GCC unroll 2 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d[1])); - vector float vyd = vec_splats(y[i].d); - vector float vd= vec_mul(vxd, vyd); - - uint16_t s16[2]; - const uint8_t * scales = (const uint8_t *)s16; - - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; - - vector signed char utmps = (vector signed char)vec_xl_len(scales, 4); - vector signed short vscales = (vector signed short)vec_unpackh(utmps); - vector signed short q4xmins0 = vec_mergeh(vscales, vscales); - q4xmins0 = vec_sld(q4xmins0, q4xmins0, 8); - - vector signed short q8ysums0 = vec_xl_len((const int16_t *)(y[i].bsums), 8); - - vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); - vector signed int prod1 = vec_mulo(q4xmins0, q8ysums0); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vd, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vd, vsumf1); - - vd = vec_mul(vyd, vec_splats(GGML_FP16_TO_FP32(x[i].d[0]))); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].qs); - vector signed char q4x00 = vec_and(qxs0, lowMask); - vector signed char q4x01 = vec_sr(qxs0, v4); - vector signed char q4x10 = vec_and(qxs1, lowMask); - vector signed char q4x11 = vec_sr(qxs1, v4); - - vector signed char q8y00 = vec_xl( 0, y[i].qs); - vector signed char q8y10 = vec_xl(16, y[i].qs); - vector signed char q8y01 = vec_xl(32, y[i].qs); - vector signed char q8y11 = vec_xl(48, y[i].qs); - - vector signed short qv00 = vec_add(vec_mule(q4x00, q8y00), vec_mulo(q4x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q4x01, q8y01), vec_mulo(q4x01, q8y01)); - vector signed short qv10 = vec_add(vec_mule(q4x10, q8y10), vec_mulo(q4x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q4x11, q8y11), vec_mulo(q4x11, q8y11)); - - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); - - vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); - vector signed int vsumi1 = vec_add(vec_mule(qv10, vs0), vec_mulo(qv10, vs0)); - vector signed int vsumi2 = vec_add(vec_mule(qv01, vs1), vec_mulo(qv01, vs1)); - vector signed int vsumi3 = vec_add(vec_mule(qv11, vs1), vec_mulo(qv11, vs1)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#else - - uint8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); - - uint16_t s16[2]; - const uint8_t * restrict scales = (const uint8_t *)s16; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - uint8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; - for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; - - const uint16_t * restrict b = (const uint16_t *)x[i].scales; - s16[0] = b[0] & 0x0f0f; - s16[1] = (b[0] >> 4) & 0x0f0f; - - sumf -= y[i].d * GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d[0]); - - for (int j = 0; j < QK_K/32; ++j) { - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - q8 += 16; a += 16; - for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; - q8 += 16; a += 16; - const float dl = d * scales[j]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); - } - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#endif - -#if QK_K == 256 void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -7962,12 +7657,10 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r float summs = 0.f; - for (int i = 0; i < nb; ++i) { - + for (int i = 0; i < nb; ++i) { const uint8_t * restrict q5 = x[i].qs; const int8_t * restrict q8 = y[i].qs; -#if QK_K == 256 const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); @@ -7977,10 +7670,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); utmp[2] = uaux; utmp[0] &= kmask1; -#else - // TODO - const float d = 0, dmin = 0; -#endif const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); @@ -8348,6 +8037,92 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); +#elif defined __loongarch_asx + GGML_UNUSED(kmask1); + GGML_UNUSED(kmask2); + GGML_UNUSED(kmask3); + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + const __m128i mzero = __lsx_vldi(0); + const __m256i mone = __lasx_xvreplgr2vr_b(1); + + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s); + const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero); + summs += dmin * __lsx_vpickve2gr_w(hsum, 0); //TODO check + + const __m128i sc128 = lasx_extracti128(mins_and_scales, 0); + const __m256i scales = lasx_insertf128(sc128, sc128); + + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); + __m256i hmask = mone; + + __m256i sumi = __lasx_xvldi(0); + + int bit = 0; + __m256i xvbit; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; + + xvbit = __lasx_xvreplgr2vr_h(bit++); + const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4); + const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); + const __m256i q5_0 = __lasx_xvadd_b(q5l_0, q5h_0); + hmask = __lasx_xvslli_h(hmask, 1); + + xvbit = __lasx_xvreplgr2vr_h(bit++); + const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4); + const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4); + const __m256i q5_1 = __lasx_xvadd_b(q5l_1, q5h_1); + hmask = __lasx_xvslli_h(hmask, 1); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1); + + p16_0 = lasx_madd_h(scale_0, p16_0); + p16_1 = lasx_madd_h(scale_1, p16_1); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + #else const uint8_t * scales = (const uint8_t*)&utmp[0]; @@ -8412,356 +8187,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r #endif } -#else - -void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mh = vdupq_n_u8(16); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x4_t q5bytes; - ggml_uint8x16x4_t q5h; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const int8_t * sc = x[i].scales; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const uint8x8_t qhbits = vld1_u8(qh); - - const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); - const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - - const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); - q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); - q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); - q5h.val[2] = vbicq_u8(mh, htmp); - q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); - - q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); - q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); - q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); - q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); - - int32_t sumi1 = sc[0] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); - int32_t sumi2 = sc[1] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); - int32_t sumi3 = sc[2] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); - int32_t sumi4 = sc[3] * vaddvq_s32(ggml_vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); - - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - - const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); - const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); - - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); - - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); - - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); - const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); - const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); - const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); - - const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); - - acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mone = _mm_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const uint8_t * restrict q5 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); - - const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); - const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); - const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); - const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); - - int64_t aux64; - memcpy(&aux64, x[i].qh, 8); - const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); - const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); - - const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); - const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); - const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); - const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); - - const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); - const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); - const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); - const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); - const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); - const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); - const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); - const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); - const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); - const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); - const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); - - const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); - const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); - - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const int8_t * sc = x[i].scales; - - const uint8_t * restrict q5 = x[i].qs; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - // load qh - vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); - vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); - - size_t vl = 16; - - // combine both qh_1 and qh_2 - vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); - - vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); - vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); - vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); - - vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); - vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); - vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); - vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); - - // load q5 - vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); - vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); - - vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); - vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); - vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); - vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); - - vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); - vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); - vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); - vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); - - // load Q8 and multiply it with Q5 - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - - int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); - int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); - int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); - int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); - - sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v1 = vec_splats((unsigned char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - -#pragma GCC unroll 2 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].qs, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd= vec_mul(vxd, vyd); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].qs); - vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].qs); - vector signed char qxs00 = (vector signed char)vec_and(qxs0, lowMask); - vector signed char qxs01 = (vector signed char)vec_sr(qxs0, v4); - vector signed char qxs10 = (vector signed char)vec_and(qxs1, lowMask); - vector signed char qxs11 = (vector signed char)vec_sr(qxs1, v4); - - vector signed char qxhs = (vector signed char)vec_xl_len(x[i].qh, 8); - vector signed char qxhs0 = vec_or(qxhs, vec_sr(vec_sld(qxhs, qxhs, 8), v1)); - vector signed char qxhs1 = vec_sr(qxhs0, v2); - vector signed char qxh00 = vec_sl(vec_andc((vector signed char)v1, qxhs0), v4); - vector signed char qxh10 = vec_sl(vec_andc((vector signed char)v1, qxhs1), v4); - vector signed char qxh01 = vec_sl(vec_andc((vector signed char)v1, vec_sr(qxhs0, v4)), v4); - vector signed char qxh11 = vec_sl(vec_andc((vector signed char)v1, vec_sr(qxhs1, v4)), v4); - - vector signed char q5x00 = vec_sub(qxs00, qxh00); - vector signed char q5x10 = vec_sub(qxs10, qxh10); - vector signed char q5x01 = vec_sub(qxs01, qxh01); - vector signed char q5x11 = vec_sub(qxs11, qxh11); - - vector signed char q8y00 = vec_xl( 0, y[i].qs); - vector signed char q8y10 = vec_xl(16, y[i].qs); - vector signed char q8y01 = vec_xl(32, y[i].qs); - vector signed char q8y11 = vec_xl(48, y[i].qs); - - vector signed short qv00 = vec_add(vec_mule(q5x00, q8y00), vec_mulo(q5x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q5x01, q8y01), vec_mulo(q5x01, q8y01)); - vector signed short qv10 = vec_add(vec_mule(q5x10, q8y10), vec_mulo(q5x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q5x11, q8y11), vec_mulo(q5x11, q8y11)); - - vector signed short vs = (vector signed short)vec_unpackh(vec_xl_len(x[i].scales, 4)); - vector signed short vs0 = vec_splat(vs, 0); - vector signed short vs1 = vec_splat(vs, 1); - vector signed short vs2 = vec_splat(vs, 2); - vector signed short vs3 = vec_splat(vs, 3); - - vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); - vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); - vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); - vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#else - - int8_t aux8[QK_K]; - int16_t aux16[16]; - float sums [8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].qs; - const uint8_t * restrict hm = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - int8_t * restrict a = aux8; - for (int l = 0; l < 32; ++l) { - a[l+ 0] = q4[l] & 0xF; - a[l+32] = q4[l] >> 4; - } - for (int is = 0; is < 8; ++is) { - uint8_t m = 1 << is; - for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const int8_t * restrict sc = x[i].scales; - - for (int j = 0; j < QK_K/16; ++j) { - const float dl = d * sc[j]; - for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); - q8 += 16; a += 16; - } - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} -#endif - - -#if QK_K == 256 void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -9297,6 +8722,84 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r *s = vec_extract(vsumf0, 0); +#elif defined __loongarch_asx + + const __m256i m4 = __lasx_xvreplgr2vr_b(0xF); + const __m256i m2 = __lasx_xvreplgr2vr_b(3); + const __m256i m32s = __lasx_xvreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0); + + __m256i sumi = __lasx_xvldi(0); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; + + const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4); + const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0); + const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1); + const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0); + __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1); + __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2); + __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3); + + __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0); + __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1); + __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2); + __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3); + + p16_0 = __lasx_xvsub_h(p16_0, q8s_0); + p16_1 = __lasx_xvsub_h(p16_1, q8s_1); + p16_2 = __lasx_xvsub_h(p16_2, q8s_2); + p16_3 = __lasx_xvsub_h(p16_3, q8s_3); + + p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0); + p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1); + p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2); + p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); + } + + acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + } + + *s = hsum_float_8(acc); + #else int8_t aux8[QK_K]; @@ -9342,388 +8845,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r #endif } -#else - -void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q6_K * restrict x = vx; - const block_q8_K * restrict y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_NEON - float sum = 0; - - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int8x16_t m32s = vdupq_n_s8(32); - const int32x4_t vzero = vdupq_n_s32(0); - - const uint8x16_t mone = vdupq_n_u8(3); - - ggml_int8x16x4_t q6bytes; - ggml_uint8x16x4_t q6h; - - for (int i = 0; i < nb; ++i) { - - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int32_t isum = 0; - - uint8x16_t qhbits = vld1q_u8(qh); - ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6); - ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); - - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits, 2); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 4); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits, 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); - q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - - sum += isum * d_all * y[i].d; - - } - *s = sum; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - - __m256i sumi = _mm256_setzero_si256(); - - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); - - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(3); - const __m128i m32s = _mm_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); - const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); - const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); - const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); - const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); - const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); - - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); - - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); - - __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); - __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); - __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); - __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); - - __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __riscv_v_intrinsic - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int32_t isum = 0; - - size_t vl = 16; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - // load Q6 - vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); - vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); - - // load qh - vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); - - vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); - vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); - - vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); - vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); - vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); - vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); - - vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); - vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); - vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); - vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); - - // load Q8 and take product - vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); - vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); - vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); - vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); - - vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); - vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); - vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); - vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; - isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; - - sumf += isum * d_all * y[i].d; - - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - -#pragma GCC unroll 2 - for (int i = 0; i < nb; ++i) { - __builtin_prefetch(x[i].ql, 0, 1); - __builtin_prefetch(x[i].qh, 0, 1); - __builtin_prefetch(y[i].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd= vec_mul(vxd, vyd); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, x[i].ql); - vector signed char qxs1 = (vector signed char)vec_xl(16, x[i].ql); - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); - - vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); - vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); - vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); - vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); - - vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); - vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); - vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); - vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); - - vector signed char q8y00 = vec_xl( 0, y[i].qs); - vector signed char q8y10 = vec_xl(16, y[i].qs); - vector signed char q8y01 = vec_xl(32, y[i].qs); - vector signed char q8y11 = vec_xl(48, y[i].qs); - - vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); - vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); - vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); - vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); - - vector signed short vs = (vector signed short)vec_unpackh(vec_xl_len(x[i].scales, 4)); - vector signed short vs0 = vec_splat(vs, 0); - vector signed short vs1 = vec_splat(vs, 1); - vector signed short vs2 = vec_splat(vs, 2); - vector signed short vs3 = vec_splat(vs, 3); - - vector signed int vsumi0 = vec_add(vec_mule(qv00, vs0), vec_mulo(qv00, vs0)); - vector signed int vsumi1 = vec_add(vec_mule(qv10, vs1), vec_mulo(qv10, vs1)); - vector signed int vsumi2 = vec_add(vec_mule(qv01, vs2), vec_mulo(qv01, vs2)); - vector signed int vsumi3 = vec_add(vec_mule(qv11, vs3), vec_mulo(qv11, vs3)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q4 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * restrict a = aux8; - for (int l = 0; l < 16; ++l) { - a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#endif - -#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) +#if defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -9953,6 +9075,49 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = 0.125f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + + const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + #else uint32_t aux32[2]; @@ -10071,64 +9236,6 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); -#if QK_K == 64 - static const uint8_t k_bit_helper[16] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m128i bit_helper = _mm_loadu_si128((const __m128i*)k_bit_helper); - const __m128i m511 = _mm_set1_epi16(511); - typedef union { - __m128i vec_index; - uint16_t index[8]; - } index_t; - - index_t idx; - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const __m128i q2_data = _mm_loadu_si128((const __m128i*)x[i].qs); - idx.vec_index = _mm_and_si128(q2_data, m511); - - const __m128i partial_sign_bits = _mm_srli_epi16(q2_data, 9); - const __m128i partial_sign_bits_upper = _mm_srli_epi16(q2_data, 13); - const __m128i partial_sign_bits_for_counting = _mm_xor_si128(partial_sign_bits, partial_sign_bits_upper); - - const __m128i odd_bits = _mm_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); - const __m128i full_sign_bits = _mm_or_si128(partial_sign_bits, odd_bits); - const __m256i full_signs = MM256_SET_M128I(full_sign_bits, full_sign_bits); - - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)(y[i].qs+32)); - - const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[idx.index[3]], iq2xs_grid[idx.index[2]], - iq2xs_grid[idx.index[1]], iq2xs_grid[idx.index[0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[idx.index[7]], iq2xs_grid[idx.index[6]], - iq2xs_grid[idx.index[5]], iq2xs_grid[idx.index[4]]); - - __m256i signs; - signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - - const __m256i sc1 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[0] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[0] & 0xf)+1)); - const __m256i sc2 = MM256_SET_M128I(_mm_set1_epi16(2*(x[i].scales[1] >> 4)+1), _mm_set1_epi16(2*(x[i].scales[1] & 0xf)+1)); - - const __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(sc1, dot1), _mm256_madd_epi16(sc2, dot2)); - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sum), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); -#else - static const uint8_t k_bit_helper[32] = { 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, @@ -10226,8 +9333,123 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * restrict s, size_t bs, const void * } *s = 0.125f * hsum_float_8(accumf); -#endif +#elif defined(__loongarch_asx) + const __m256i mone = __lasx_xvreplgr2vr_b(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); + const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); + const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); + const __m256i m511 = __lasx_xvreplgr2vr_h(511); + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = __lsx_vreplgr2vr_d(aux64); + stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); + const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; + aux_gindex = __lasx_xvand_v(q2_data, m511); + + const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); + const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); + const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); + + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + + const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); + const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); + const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); + const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); + + __m256i signs; + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); + + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); + const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); + + const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); + + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); #elif defined(__POWER9_VECTOR__) vector float vsumf0 = vec_splats(0.0f); vector float vsumf1 = vec_splats(0.0f); @@ -10644,6 +9866,81 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * restrict s, size_t bs, const void * vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = 0.125f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + uint64_t aux64; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * restrict q8 = y[i].qs; + + __m128i tmp1; + memcpy(&aux64, x[i].scales, 8); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); + const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); + const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + #else float sumf = 0; @@ -10883,6 +10180,54 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = 0.25f * vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict gas = x[i].qs + QK_K/4; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + + const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.25f * hsum_float_8(accumf); + #else uint32_t aux32; @@ -10957,10 +10302,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * ggml_int8x16x4_t q8b; vec_index_t idx; -#if QK_K == 256 uint32_t scales32[2]; const uint8_t * scales8 = (const uint8_t *)scales32; -#endif float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -10970,11 +10313,9 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * const uint16_t * restrict signs = (const uint16_t *)x[i].signs; const int8_t * restrict q8 = y[i].qs; -#if QK_K == 256 memcpy(scales32, x[i].scales, 4); scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; -#endif int sumi1 = 0, sumi2 = 0; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { @@ -11015,13 +10356,9 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); -#if QK_K == 256 + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; -#else - sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); - sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); -#endif } sumf += d*(sumi1 + sumi2); } @@ -11228,6 +10565,89 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + + __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * restrict qs = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const uint16_t * restrict signs = (const uint16_t *)x[i].signs; + const int8_t * restrict q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; + idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); + idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); + idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); + idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = lasx_set_w( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = lasx_set_w( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = hsum_float_8(accumf); + #else float sumf = 0.f; @@ -11275,12 +10695,22 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void * } -#ifdef __AVX2__ +#if defined(__AVX2__) static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { const __m256i ax = _mm256_sign_epi8(x, x); const __m256i sy = _mm256_sign_epi8(y, x); return _mm256_maddubs_epi16(ax, sy); } +#elif defined(__loongarch_asx) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = __lasx_xvsigncov_b(x, x); + const __m256i sy = __lasx_xvsigncov_b(x, y); + __m256i tmp1, tmp2, tmp3; + tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy); + tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy); + tmp3 = __lasx_xvadd_h(tmp1, tmp2); + return __lasx_xvsat_h(tmp3, 15); +} #endif void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { @@ -11489,6 +10919,62 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + __m256 accum = (__m256)__lasx_xvldi(0); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = __lasx_xvldi(0); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); + + __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); + + qs += 8; + const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + + __m256i tmp1, tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); + const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); + + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); + const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); + accum1 += d * sumi1; + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + #else float sumf = 0; @@ -11536,17 +11022,10 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const int nb = n / QK_K; -#if QK_K != 64 iq1m_scale_t scale; -#endif #if defined __ARM_NEON - -#if QK_K == 64 - const int32x4_t mask = vdupq_n_s32(0xf); -#else const int32x4_t mask = vdupq_n_s32(0x7); -#endif const int32x4_t mone = vdupq_n_s32(1); const int32x4_t mzero = vdupq_n_s32(0); @@ -11570,9 +11049,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const uint8_t * qh = x[i].qh; const uint16_t * sc = (const uint16_t *)x[i].scales; -#if QK_K != 64 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); -#endif int32x4_t sumi1 = mzero; int32x4_t sumi2 = mzero; @@ -11601,11 +11078,8 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); const int32x4_t p34 = vpaddq_s32(p3, p4); -#if QK_K == 64 - int32x4_t scales_4 = ggml_vld1q_u32(sc[0] >> 0, sc[0] >> 4, sc[0] >> 8, sc[0] >> 12); -#else int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); -#endif + scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); sumi1 = vmlaq_s32(sumi1, scales_4, p12); @@ -11615,22 +11089,14 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void } -#if QK_K == 64 - sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); -#else sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); -#endif } *s = sumf; #elif defined __AVX2__ -#if QK_K == 64 - const __m256i mask = _mm256_set1_epi16(0xf); -#else const __m256i mask = _mm256_set1_epi16(0x7); -#endif const __m256i mone = _mm256_set1_epi16(1); __m256 accum1 = _mm256_setzero_ps(); @@ -11642,9 +11108,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const uint8_t * qh = x[i].qh; const uint16_t * sc = (const uint16_t *)x[i].scales; -#if QK_K != 64 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); -#endif __m256i sumi1 = _mm256_setzero_si256(); __m256i sumi2 = _mm256_setzero_si256(); @@ -11674,13 +11138,10 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const __m256i dot3 = mul_add_epi8(delta1, q8b_1); const __m256i dot4 = mul_add_epi8(delta2, q8b_2); -#if QK_K == 64 - __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 4), _mm_set1_epi16(sc[0] >> 0)); - __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[0] >> 12), _mm_set1_epi16(sc[0] >> 8)); -#else + __m256i scale1 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 3), _mm_set1_epi16(sc[ib/2] >> 0)); __m256i scale2 = MM256_SET_M128I(_mm_set1_epi16(sc[ib/2] >> 9), _mm_set1_epi16(sc[ib/2] >> 6)); -#endif + scale1 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale1, mask), 1), mone); scale2 = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scale2, mask), 1), mone); const __m256i p1 = _mm256_madd_epi16(dot1, scale1); @@ -11694,14 +11155,10 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void qs += 8; qh += 4; } -#if QK_K == 64 - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); -#else const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); -#endif + accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); - } *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); @@ -11718,9 +11175,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void const uint8_t * qh = x[i].qh; const uint16_t * sc = (const uint16_t *)x[i].scales; -#if QK_K != 64 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); -#endif int sumi1 = 0, sumi2 = 0; for (int ib = 0; ib < QK_K/32; ++ib) { @@ -11740,24 +11195,17 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void sum1[l/2] += lsum1; sum2[l/2] += lsum2*delta[l]; } -#if QK_K == 64 - const int ls1 = 2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1; - const int ls2 = 2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1; -#else + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; -#endif + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; sumi2 += sum2[0] * ls1 + sum2[1] * ls2; qs += 4; qh += 2; } -#if QK_K == 64 - sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); -#else sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); -#endif } *s = sumf; @@ -11890,6 +11338,39 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined (__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + const __m256i mone = __lasx_xvreplgr2vr_h(1); + + __m256 accum1 = (__m256)__lasx_xvldi(0); + __m256 accum2 = (__m256)__lasx_xvldi(0); + for (int ib = 0; ib < nb; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[0].qs, 0); + const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[1].qs, 0); + const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[0].qs, 0); + const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[1].qs, 0); + const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); + const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = lasx_madd_h(p16_1, mone); + const __m256i p_2 = lasx_madd_h(p16_2, mone); + accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)), + __lasx_xvffint_s_w(p_1), accum1); + accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)), + __lasx_xvffint_s_w(p_2), accum2); + + y += 2; + x += 2; + } + + *s = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + #else float sumf = 0; for (int ib = 0; ib < nb; ++ib) { @@ -11912,9 +11393,6 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * UNUSED(by); UNUSED(bs); assert(n % QK_K == 0); -#if QK_K == 64 - ggml_vec_dot_iq4_nl_q8_0(n, s, bs, vx, bx, vy, by, nrc); -#else const block_iq4_xs * restrict x = vx; const block_q8_K * restrict y = vy; @@ -12100,6 +11578,80 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); *s = vec_extract(vsumf0, 0); + +#elif defined(__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + + __m256 accum = (__m256)__lasx_xvldi(0); + __m256i tmp1; + __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask; + + mask_8f = __lsx_vreplgr2vr_b(0x8f); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + __m128i zero = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp3 = __lsx_vand_v(tmp0, mask); + tmp3 = __lsx_vshuf_b(values128, zero, tmp3); + + tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp4 = __lsx_vand_v(tmp0, mask); + tmp4 = __lsx_vshuf_b(values128, zero, tmp4); + + const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4); + + tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp3 = __lsx_vand_v(tmp0, mask); + tmp3 = __lsx_vshuf_b(values128, zero, tmp3); + + tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f); + tmp0 = __lsx_vori_b(tmp2, 0x10); + mask = __lsx_vsle_b(zero, tmp2); + tmp4 = __lsx_vand_v(tmp0, mask); + tmp4 = __lsx_vshuf_b(values128, zero, tmp4); + + const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4); + + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + __m256i tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1); + const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6); + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1); + const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6); + sumi1 = __lasx_xvadd_w(p_1, sumi1); + sumi2 = __lasx_xvadd_w(p_2, sumi2); + } + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + #else float sumf = 0; for (int ibl = 0; ibl < nb; ++ibl) { @@ -12133,7 +11685,6 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * } *s = sumf; #endif -#endif } // ================================ IQ2 quantization ============================================= @@ -12709,7 +12260,7 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict printf("\n"); GGML_ASSERT(false); } - q2[2*ib+0] |= (grid_index << 8*k); + q2[2*ib+0] |= ((uint32_t) grid_index << 8*k); q2[2*ib+1] |= (block_signs[k] << 7*k); } GGML_ASSERT(scale >= 0); @@ -13951,10 +13502,6 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy const float * xx; for (int ibl = 0; ibl < nbl; ++ibl) { - -#if QK_K == 64 - y[ibl].d = GGML_FP32_TO_FP16(0.f); -#endif memset(y[ibl].qs, 0, QK_K/8); memset(y[ibl].qh, 0, QK_K/16); memset(y[ibl].scales, 0, QK_K/32); @@ -14129,22 +13676,13 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy } uint16_t * sc = (uint16_t *)y[ibl].scales; -#if QK_K == 64 - float d = max_scale/31; -#else float d = max_scale/15; -#endif float id = 1/d; float sumqx_f = 0, sumq2_f = 0; for (int ib = 0; ib < QK_K/block_size; ++ib) { int l = nearest_int(0.5f*(id*scales[ib+0]-1)); -#if QK_K == 64 - l = MAX(0, MIN(15, l)); - sc[ib/4] |= (l << 4*(ib%4)); -#else l = MAX(0, MIN(7, l)); sc[ib/4] |= (l << 3*(ib%4)); -#endif y[ibl].qh[ib] |= masks[shifts[ib]]; const float * xb = xbl + block_size*ib; if (quant_weights) { @@ -14167,14 +13705,10 @@ static void quantize_row_iq1_m_impl(const float * restrict x, void * restrict vy } if (sumq2_f > 0) d = sumqx_f/sumq2_f; s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed. -#if QK_K == 64 - y[ibl].d = s.f16; -#else sc[0] |= ((s.u16 & 0x000f) << 12); sc[1] |= ((s.u16 & 0x00f0) << 8); sc[2] |= ((s.u16 & 0x0f00) << 4); sc[3] |= ((s.u16 & 0xf000) << 0); -#endif } } @@ -14363,9 +13897,6 @@ void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * rest } size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { -#if QK_K == 64 - return quantize_iq4_nl(src, dst, nrow, n_per_row, quant_weights); -#else GGML_ASSERT(n_per_row%QK_K == 0); int64_t nblock = n_per_row/QK_K; char * qrow = (char *)dst; @@ -14383,7 +13914,6 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t qrow += nblock*sizeof(block_iq4_xs); } return nrow * nblock * sizeof(block_iq4_xs); -#endif } void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) { @@ -14795,19 +14325,11 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } break; case GGML_TYPE_Q4_K: { - #ifdef GGML_QKK_64 - VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d[0], d[1]); - #else VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_K, data, nb, d, dmin); - #endif } break; case GGML_TYPE_Q5_K: { - #ifdef GGML_QKK_64 - VALIDATE_ROW_DATA_D_F16_IMPL(block_q5_K, data, nb); - #else VALIDATE_ROW_DATA_DM_F16_IMPL(block_q5_K, data, nb, d, dmin); - #endif } break; case GGML_TYPE_Q6_K: { @@ -14830,18 +14352,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { const block_iq1_m * q = (const block_iq1_m *) data; for (size_t i = 0; i < nb; ++i) { - #if QK_K == 64 - if (!validate_fp16(q[i].d, i)) { - return false; - } - #else iq1m_scale_t scale; const uint16_t * sc = (const uint16_t *)q[i].scales; scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); if (!validate_fp16(scale.f16, i)) { return false; } - #endif } } break; case GGML_TYPE_IQ2_XXS: @@ -14866,12 +14382,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte VALIDATE_ROW_DATA_D_F16_IMPL(block_iq3_s, data, nb); } break; case GGML_TYPE_IQ4_XS: - #if QK_K != 64 { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_xs, data, nb); } break; - #endif - // with QK_K == 64, iq4_xs is iq4_nl case GGML_TYPE_IQ4_NL: { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); diff --git a/llama/ggml-quants.h b/llama/ggml-quants.h index 54ba7e38..f6d07143 100644 --- a/llama/ggml-quants.h +++ b/llama/ggml-quants.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/ggml.c b/llama/ggml.c index b4d99047..398bf731 100644 --- a/llama/ggml.c +++ b/llama/ggml.c @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -31,6 +31,7 @@ #include "ggml-quants.h" #include "ggml.h" + #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) @@ -54,6 +55,10 @@ #include #endif +#ifdef GGML_USE_OPENMP +#include +#endif + #ifdef GGML_USE_METAL #include #endif @@ -86,6 +91,9 @@ typedef volatile LONG atomic_int; typedef atomic_int atomic_bool; +typedef atomic_int atomic_flag; + +#define ATOMIC_FLAG_INIT 0 static void atomic_store(atomic_int * ptr, LONG val) { InterlockedExchange(ptr, val); @@ -99,6 +107,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) { static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) { return atomic_fetch_add(ptr, -(dec)); } +static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) { + return InterlockedExchange(ptr, 1); +} +static void atomic_flag_clear(atomic_flag * ptr) { + InterlockedExchange(ptr, 0); +} typedef HANDLE pthread_t; @@ -309,17 +323,12 @@ inline static void * ggml_calloc(size_t num, size_t size) { #if defined(GGML_USE_ACCELERATE) #include -#if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions -#include "ggml-opencl.h" -#endif #elif defined(GGML_USE_OPENBLAS) #if defined(GGML_BLAS_USE_MKL) #include #else #include #endif -#elif defined(GGML_USE_CLBLAST) -#include "ggml-opencl.h" #endif // floating point type used to accumulate sums @@ -432,10 +441,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) { int i = 0; #if defined(__AVX512BF16__) for (; i + 32 <= n; i += 32) { - _mm512_storeu_ps( - (__m512 *)(y + i), - (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), - _mm512_loadu_ps(x + i))); + _mm512_storeu_si512( + (__m512i *)(y + i), + m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), + _mm512_loadu_ps(x + i)))); } #endif for (; i < n; i++) { @@ -897,22 +906,14 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { }, [GGML_TYPE_IQ4_XS] = { .type_name = "iq4_xs", -#if QK_K == 64 - .blck_size = QK4_NL, -#else .blck_size = QK_K, -#endif .type_size = sizeof(block_iq4_xs), .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_iq4_xs, .from_float = quantize_row_iq4_xs, .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference, .vec_dot = ggml_vec_dot_iq4_xs_q8_K, -#if QK_K == 64 - .vec_dot_type = GGML_TYPE_Q8_0, -#else .vec_dot_type = GGML_TYPE_Q8_K, -#endif .nrows = 1, }, [GGML_TYPE_Q8_K] = { @@ -1549,6 +1550,196 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) { #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#elif defined(__loongarch_asx) + +#define GGML_SIMD + +// F32 LASX +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 8 + +#define GGML_F32x8 __m256 +#define GGML_F32x8_ZERO (__m256)__lasx_xvldi(0) +#define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x)) +#define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0) +#define GGML_F32x8_STORE(x,y) __lasx_xvst((y), (x), 0) +#define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a) +#define GGML_F32x8_ADD __lasx_xvfadd_s +#define GGML_F32x8_MUL __lasx_xvfmul_s +#define GGML_F32x8_REDUCE(res, x) \ +do { \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lasx_xvfadd_s(x[i], x[offset+i]); \ + } \ + float *tmp_p = (float *)&x[0]; \ + res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7]; \ +} while (0) +// TODO: is this optimal ? + +#define GGML_F32_VEC GGML_F32x8 +#define GGML_F32_VEC_ZERO GGML_F32x8_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x8_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x8_LOAD +#define GGML_F32_VEC_STORE GGML_F32x8_STORE +#define GGML_F32_VEC_FMA GGML_F32x8_FMA +#define GGML_F32_VEC_ADD GGML_F32x8_ADD +#define GGML_F32_VEC_MUL GGML_F32x8_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE + +// F16 LASX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 8 + +// F16 arithmetic is not supported by AVX, so we use F32 instead + +#define GGML_F32Cx8 __m256 +#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0) +#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) + +static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) { + float tmp[8]; + + for (int i = 0; i < 8; i++) { + tmp[i] = GGML_FP16_TO_FP32(x[i]); + } + + return (__m256)__lasx_xvld(tmp, 0); +} +static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { + float arr[8]; + + __lasx_xvst(y, arr, 0); + + for (int i = 0; i < 8; i++) { + x[i] = GGML_FP32_TO_FP16(arr[i]); + } +} +#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x) +#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) + +#define GGML_F32Cx8_FMA GGML_F32x8_FMA +#define GGML_F32Cx8_ADD __lasx_xvfadd_s +#define GGML_F32Cx8_MUL __lasx_xvfmul_s +#define GGML_F32Cx8_REDUCE GGML_F32x8_REDUCE + +#define GGML_F16_VEC GGML_F32Cx8 +#define GGML_F16_VEC_ZERO GGML_F32Cx8_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx8_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx8_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx8_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx8_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx8_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx8_REDUCE + +#elif defined(__loongarch_sx) + +#define GGML_SIMD + +// F32 LSX + +#define GGML_F32_STEP 32 +#define GGML_F32_EPR 4 + +#define GGML_F32x4 __m128 +#define GGML_F32x4_ZERO __lsx_vldi(0) +#define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define GGML_F32x4_LOAD(x) __lsx_vld((x), 0) +#define GGML_F32x4_STORE((x),(y)) __lsx_vst((y), (x), 0) +#define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) +#define GGML_F32x4_ADD __lsx_vfadd_s +#define GGML_F32x4_MUL __lsx_vfmul_s +#define GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ + } \ + __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \ + tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ + tmp = __lsx_vsrli_d((__m128i)t0, 32); \ + tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + res = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ +} + +#define GGML_F32_VEC GGML_F32x4 +#define GGML_F32_VEC_ZERO GGML_F32x4_ZERO +#define GGML_F32_VEC_SET1 GGML_F32x4_SET1 +#define GGML_F32_VEC_LOAD GGML_F32x4_LOAD +#define GGML_F32_VEC_STORE GGML_F32x4_STORE +#define GGML_F32_VEC_FMA GGML_F32x4_FMA +#define GGML_F32_VEC_ADD GGML_F32x4_ADD +#define GGML_F32_VEC_MUL GGML_F32x4_MUL +#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE + +// F16 LSX + +#define GGML_F16_STEP 32 +#define GGML_F16_EPR 4 + +static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { + float tmp[4]; + + tmp[0] = GGML_FP16_TO_FP32(x[0]); + tmp[1] = GGML_FP16_TO_FP32(x[1]); + tmp[2] = GGML_FP16_TO_FP32(x[2]); + tmp[3] = GGML_FP16_TO_FP32(x[3]); + + return __lsx_vld(tmp, 0); +} + +static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { + float arr[4]; + + __lsx_vst(y, arr, 0); + + x[0] = GGML_FP32_TO_FP16(arr[0]); + x[1] = GGML_FP32_TO_FP16(arr[1]); + x[2] = GGML_FP32_TO_FP16(arr[2]); + x[3] = GGML_FP32_TO_FP16(arr[3]); +} + +#define GGML_F32Cx4 __m128 +#define GGML_F32Cx4_ZERO __lsx_vldi(0) +#define GGML_F32Cx4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) +#define GGML_F32Cx4_LOAD(x) __lsx_f16x4_load(x) +#define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) +#define GGML_F32Cx4_FMA GGML_F32x4_FMA +#define GGML_F32Cx4_ADD __lsx_vfadd_s +#define GGML_F32Cx4_MUL __lsx_vfmul_s +#define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + +#define GGML_F16_VEC GGML_F32Cx4 +#define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO +#define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F32Cx4_FMA +#define GGML_F16_VEC_ADD GGML_F32Cx4_ADD +#define GGML_F16_VEC_MUL GGML_F32Cx4_MUL +#define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE + #endif // GGML_F32_ARR / GGML_F16_ARR @@ -1591,7 +1782,7 @@ struct ggml_compute_state_shared { int64_t perf_node_start_cycles; int64_t perf_node_start_time_us; - const int n_threads; + int n_threads; // synchronization primitives atomic_int n_active; // num active threads @@ -1692,10 +1883,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t __m512 c1 = _mm512_setzero_ps(); __m512 c2 = _mm512_setzero_ps(); for (; i + 64 <= n; i += 64) { - c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)), - (__m512bh)_mm512_loadu_ps((const float *)(y + i))); - c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)), - (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32))); + c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), + m512bh(_mm512_loadu_si512((y + i)))); + c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), + m512bh(_mm512_loadu_si512((y + i + 32)))); } sumf += (ggml_float)_mm512_reduce_add_ps(c1); sumf += (ggml_float)_mm512_reduce_add_ps(c2); @@ -2102,6 +2293,11 @@ inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); } +#if __FINITE_MATH_ONLY__ +#error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" +#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461" +#endif + #if defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine @@ -2151,32 +2347,27 @@ inline static __m512 ggml_v_expf(__m512 x) { const __m512 r = _mm512_set1_ps(0x1.8p23f); const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); const __m512 n = _mm512_sub_ps(z, r); - const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); - const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23); - const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1)))); - const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ); - const __m512 u = _mm512_mul_ps(b, b); - const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, - _mm512_set1_ps(0x1.573e2ep-5f)), u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, - _mm512_set1_ps(0x1.fffdb6p-2f))), - u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b)); - if (_mm512_kortestz(c, c)) - return _mm512_fmadd_ps(j, k, k); - const __m512i g = _mm512_and_si512( - _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)), - _mm512_set1_epi32(0x82000000u)); - const __m512 s1 = - _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u))); - const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g)); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); const __mmask16 d = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); - return _mm512_mask_blend_ps( - d, _mm512_mask_blend_ps( - c, _mm512_fmadd_ps(k, j, k), - _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)), - _mm512_mul_ps(s1, s1)); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); } // computes silu x/(1+exp(-x)) in single precision vector @@ -2515,9 +2706,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "ARGSORT", "LEAKY_RELU", - "FLASH_ATTN", "FLASH_ATTN_EXT", - "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", "SSM_SCAN", @@ -2543,7 +2732,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2605,9 +2794,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "argsort(x)", "leaky_relu(x)", - "flash_attn(x)", "flash_attn_ext(x)", - "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", "ssm_scan(x)", @@ -2633,7 +2820,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2732,24 +2919,20 @@ struct ggml_state { // global state static struct ggml_state g_state; -static atomic_int g_state_barrier = 0; +static atomic_flag g_state_critical = ATOMIC_FLAG_INIT; // barrier via spin lock inline static void ggml_critical_section_start(void) { - int processing = atomic_fetch_add(&g_state_barrier, 1); - - while (processing > 0) { - // wait for other threads to finish - atomic_fetch_sub(&g_state_barrier, 1); - sched_yield(); // TODO: reconsider this - processing = atomic_fetch_add(&g_state_barrier, 1); + while (atomic_flag_test_and_set(&g_state_critical)) { + // spin + sched_yield(); } } // TODO: make this somehow automatically executed // some sort of "sentry" mechanism inline static void ggml_critical_section_end(void) { - atomic_fetch_sub(&g_state_barrier, 1); + atomic_flag_clear(&g_state_critical); } #if defined(__gnu_linux__) @@ -3065,7 +3248,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } -static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { +GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { + return ggml_is_contiguous(tensor); +} + +GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return @@ -3074,6 +3261,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } +GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -3206,10 +3401,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) { GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f); } -#if defined(GGML_USE_CLBLAST) - ggml_cl_init(); -#endif - ggml_setup_op_has_task_pass(); is_first_call = false; @@ -4731,10 +4922,21 @@ struct ggml_tensor * ggml_repeat_back( // ggml_concat struct ggml_tensor * ggml_concat( - struct ggml_context* ctx, - struct ggml_tensor* a, - struct ggml_tensor* b) { - GGML_ASSERT(a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] && a->ne[3] == b->ne[3]); + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int dim) { + GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); + + int64_t ne[GGML_MAX_DIMS]; + for (int d = 0; d < GGML_MAX_DIMS; ++d) { + if (d == dim) { + ne[d] = a->ne[d] + b->ne[d]; + continue; + } + GGML_ASSERT(a->ne[d] == b->ne[d]); + ne[d] = a->ne[d]; + } bool is_node = false; @@ -4742,7 +4944,9 @@ struct ggml_tensor * ggml_concat( is_node = true; } - struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, a->ne[0], a->ne[1], a->ne[2] + b->ne[2], a->ne[3]); + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, GGML_MAX_DIMS, ne); + + ggml_set_op_params_i32(result, 0, dim); result->op = GGML_OP_CONCAT; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -4862,6 +5066,7 @@ struct ggml_tensor * ggml_leaky_relu( } struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + ggml_set_op_params(result, &negative_slope, sizeof(negative_slope)); result->op = GGML_OP_LEAKY_RELU; @@ -6068,23 +6273,28 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow, - float xpos_base, - bool xpos_down, bool inplace) { + GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + if (c) { + GGML_ASSERT(c->type == GGML_TYPE_F32); + GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + bool is_node = false; if (a->grad) { @@ -6093,21 +6303,20 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; + result->src[2] = c; return result; } @@ -6117,10 +6326,9 @@ struct ggml_tensor * ggml_rope( struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false ); } @@ -6129,10 +6337,49 @@ struct ggml_tensor * ggml_rope_inplace( struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx) { + int mode) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true + ); +} + +struct ggml_tensor * ggml_rope_ext( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false + ); +} + +struct ggml_tensor * ggml_rope_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return ggml_rope_impl( + ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } @@ -6142,8 +6389,7 @@ struct ggml_tensor * ggml_rope_custom( struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6151,8 +6397,8 @@ struct ggml_tensor * ggml_rope_custom( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, false ); } @@ -6162,8 +6408,7 @@ struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_tensor * b, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -6171,42 +6416,31 @@ struct ggml_tensor * ggml_rope_custom_inplace( float beta_fast, float beta_slow) { return ggml_rope_impl( - ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, true ); } -struct ggml_tensor * ggml_rope_xpos_inplace( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - int n_dims, - float base, - bool down) { - return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); -} - // ggml_rope_back struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down) { + float beta_slow) { GGML_ASSERT(ggml_is_vector(b)); GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT(c == NULL && "freq factors not implemented yet"); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -6218,15 +6452,13 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); - memcpy(params + 11, &xpos_base, sizeof(float)); - memcpy(params + 12, &xpos_down, sizeof(bool)); ggml_set_op_params(result, params, sizeof(params)); result->op = GGML_OP_ROPE_BACK; @@ -6750,38 +6982,6 @@ struct ggml_tensor * ggml_top_k( return result; } -// ggml_flash_attn - -struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked) { - GGML_ASSERT(ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, q); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, q->ne); - - int32_t t = masked ? 1 : 0; - ggml_set_op_params(result, &t, sizeof(t)); - - result->op = GGML_OP_FLASH_ATTN; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - - return result; -} - // ggml_flash_attn_ext struct ggml_tensor * ggml_flash_attn_ext( @@ -6841,38 +7041,6 @@ void ggml_flash_attn_ext_set_prec( ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second } -// ggml_flash_ff - -struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1) { - GGML_ASSERT(ggml_can_mul_mat(b0, a)); - // TODO: more checks - - bool is_node = false; - - if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) { - is_node = true; - } - - //struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, a->ne); - - result->op = GGML_OP_FLASH_FF; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; - result->src[0] = a; - result->src[1] = b0; - result->src[2] = b1; - result->src[3] = c0; - result->src[4] = c1; - - return result; -} - // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -6882,6 +7050,8 @@ struct ggml_tensor * ggml_flash_attn_back( struct ggml_tensor * v, struct ggml_tensor * d, bool masked) { + GGML_ASSERT(false && "TODO: adapt to ggml_flash_attn_ext() changes"); + GGML_ASSERT(ggml_can_mul_mat(k, q)); // TODO: check if vT can be multiplied by (k*qT) @@ -8874,17 +9044,6 @@ static void ggml_compute_forward_add_f32( const int ith = params->ith; const int nth = params->nth; -#ifdef GGML_USE_CLBLAST - if (src1->backend == GGML_BACKEND_TYPE_GPU) { - // TODO: OpenCL kernel support full broadcast - GGML_ASSERT(ggml_can_repeat_rows(src1, src0)); - if (ith == 0) { - ggml_cl_add(src0, src1, dst); - } - return; - } -#endif - const int nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -9992,17 +10151,6 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; -#if defined(GGML_USE_CLBLAST) - if (src1->backend == GGML_BACKEND_TYPE_GPU) { - // TODO: OpenCL kernel support full broadcast - GGML_ASSERT(ggml_can_repeat_rows(src1, src0)); - if (ith == 0) { - ggml_cl_mul(src0, src1, dst); - } - return; - } -#endif - const int64_t nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -10835,26 +10983,29 @@ static void ggml_compute_forward_concat_f32( GGML_ASSERT(nb00 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float)); + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim >= 0 && dim < 4); + + int64_t o[4] = {0, 0, 0, 0}; + o[dim] = src0->ne[dim]; + + const float * x; + + // TODO: smarter multi-theading for (int i3 = 0; i3 < ne3; i3++) { for (int i2 = ith; i2 < ne2; i2 += nth) { - if (i2 < ne02) { // src0 - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - const float * x = (float *)((char *) src0->data + i0 * nb00 + i1 * nb01 + i2 * nb02 + i3 * nb03); - - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); - *y = *x; + for (int i1 = 0; i1 < ne1; i1++) { + for (int i0 = 0; i0 < ne0; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03); + } else { + x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13); } - } - } // src1 - else { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - const float * x = (float *)((char *) src1->data + i0 * nb10 + i1 * nb11 + (i2 - ne02) * nb12 + i3 * nb13); - float * y = (float *)((char *)dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); - *y = *x; - } + float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + + *y = *x; } } } @@ -10862,8 +11013,8 @@ static void ggml_compute_forward_concat_f32( } static void ggml_compute_forward_concat( - const struct ggml_compute_params* params, - struct ggml_tensor* dst) { + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -11256,8 +11407,8 @@ static void ggml_compute_forward_gelu_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { @@ -11319,8 +11470,8 @@ static void ggml_compute_forward_gelu_quick_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { @@ -11382,8 +11533,8 @@ static void ggml_compute_forward_silu_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { @@ -11494,9 +11645,9 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * grad = dst->src[1]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(grad)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, grad)); @@ -12235,15 +12386,6 @@ static void ggml_compute_forward_mul_mat( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows -#if defined(GGML_USE_CLBLAST) - if (ggml_cl_can_mul_mat(src0, src1, dst)) { - if (params->ith == 0 && params->type == GGML_TASK_TYPE_COMPUTE) { - ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); - } - return; - } -#endif - #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(dst)) { const int64_t ne_plane = ne01*ne00; @@ -12691,8 +12833,6 @@ static void ggml_compute_forward_out_prod_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // TODO: #if defined(GGML_USE_CLBLAST) - #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) bool use_blas = ggml_is_matrix(src0) && ggml_is_matrix(src1) && @@ -12890,7 +13030,7 @@ static void ggml_compute_forward_out_prod_q_f32( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows - // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST) + // TODO: #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { @@ -14087,8 +14227,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta -) { + float * cos_theta, float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -14105,18 +14244,19 @@ static void rope_yarn( // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get // `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -static float ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { - return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } static void ggml_rope_cache_init( - float theta_base, float freq_scale, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, - float * cache, float sin_sign, float theta_scale -) { + float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; rope_yarn( - theta, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] ); cache[i0 + 1] *= sin_sign; @@ -14125,11 +14265,11 @@ static void ggml_rope_cache_init( } GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { // start and end correction dims - float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base)); - float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base)); + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); dims[0] = MAX(0, start); dims[1] = MIN(n_dims - 1, end); } @@ -14141,6 +14281,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14148,15 +14289,11 @@ static void ggml_compute_forward_rope_f32( float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - // these two only relevant for xPos RoPE: - float xpos_base; - bool xpos_down; - //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); @@ -14164,8 +14301,6 @@ static void ggml_compute_forward_rope_f32( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); GGML_TENSOR_UNARY_OP_LOCALS @@ -14193,12 +14328,18 @@ static void ggml_compute_forward_rope_f32( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; + + const float * freq_factors = NULL; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. @@ -14212,106 +14353,57 @@ static void ggml_compute_forward_rope_f32( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - const float x2 = src[n_dims]; - const float x3 = src[n_dims/2*3]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - dst_data[n_dims] = x2*cos_block_theta - x3*sin_block_theta; - dst_data[n_dims/2*3] = x2*sin_block_theta + x3*cos_block_theta; - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; - // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; - if (xpos_down) zeta = 1.0f / zeta; - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[1]; - dst_data[0] = x0*cos_theta*zeta - x1*sin_theta*zeta; - dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; } } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // it seems we have to rope just the first n_dims elements and do nothing with the rest - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - theta_base *= freq_scale; - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t ib = 0; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - float cos_theta, sin_theta; - rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); - sin_theta *= sin_sign; + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - theta_base *= theta_scale; + const float x0 = src[0]; + const float x1 = src[n_dims/2]; - const int64_t i0 = ib*n_dims + ic/2; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; } } + + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } } } +// TODO: deduplicate f16/f32 code static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -14319,6 +14411,7 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14329,8 +14422,8 @@ static void ggml_compute_forward_rope_f16( //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -14364,12 +14457,18 @@ static void ggml_compute_forward_rope_f16( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); - const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; - const bool is_glm = mode & 4; + + const float * freq_factors = NULL; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; + } // backward process uses inverse rotation by cos and sin. // cos and sin build a rotation matrix, where the inverse is the transpose. @@ -14383,43 +14482,14 @@ static void ggml_compute_forward_rope_f16( const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox - ggml_rope_cache_init(p, freq_scale, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta_base = (float)p; - - if (is_glm) { - theta_base = MIN(p, n_ctx - 2); - float block_theta = MAX(p - (n_ctx - 2), 0); - for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta_base); - const float sin_theta = sinf(theta_base) * sin_sign; - const float cos_block_theta = cosf(block_theta); - const float sin_block_theta = sinf(block_theta) * sin_sign; - - theta_base *= theta_scale; - block_theta *= theta_scale; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - const float x2 = GGML_FP16_TO_FP32(src[n_dims]); - const float x3 = GGML_FP16_TO_FP32(src[n_dims/2*3]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); - dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); - } - } else if (!is_neox) { - for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + if (!is_neox) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; @@ -14433,47 +14503,30 @@ static void ggml_compute_forward_rope_f16( dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } else { - // TODO: this might be wrong for ne0 != n_dims - need double check - // it seems we have to rope just the first n_dims elements and do nothing with the rest - // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26 - theta_base *= freq_scale; - for (int64_t ic = 0; ic < ne0; ic += 2) { - if (ic < n_dims) { - const int64_t ib = 0; + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; - // simplified from `(ib * n_dims + ic) * inv_ndims` - float cur_rot = inv_ndims * ic - ib; + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; - float cos_theta, sin_theta; - rope_yarn( - theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, - &cos_theta, &sin_theta - ); - sin_theta *= sin_sign; + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - theta_base *= theta_scale; + const float x0 = GGML_FP16_TO_FP32(src[0]); + const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - const int64_t i0 = ib*n_dims + ic/2; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = GGML_FP16_TO_FP32(src[0]); - const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); - } else { - const int64_t i0 = ic; - - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } + + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } } @@ -15484,400 +15537,6 @@ static void ggml_compute_forward_argsort( } } -// ggml_compute_forward_flash_attn - -static void ggml_compute_forward_flash_attn_f32( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - const struct ggml_tensor * k = dst->src[1]; - const struct ggml_tensor * v = dst->src[2]; - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(float)); - GGML_ASSERT(nbv0 == sizeof(float)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; - for (int64_t ic = 0; ic < masked_begin; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f32(neq0, - S + i1, 0, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); - } - - // scale - ggml_vec_scale_f32(masked_begin, S, scale); - - for (int64_t i = masked_begin; i < M; i++) { - S[i] = -INFINITY; - } - - // softmax - // exclude known -INF S[..] values from max and loop - // dont forget to set their SW values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(masked_begin, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - sum = ggml_vec_soft_max_f32(Mup, S, S, max); -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(masked_begin, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < masked_begin; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f32(masked_begin, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0, - (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0, - S, 0, 1); - } - } -} - -static void ggml_compute_forward_flash_attn_f16( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - const struct ggml_tensor * k = dst->src[1]; - const struct ggml_tensor * v = dst->src[2]; - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = neq0; - const int64_t N = neq1; - const int64_t P = nek1 - N; - const int64_t M = P + N; - - const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL); - - GGML_ASSERT(ne0 == D); - GGML_ASSERT(ne1 == N); - GGML_ASSERT(P >= 0); - - GGML_ASSERT(nbq0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); - - GGML_ASSERT(neq0 == D); - GGML_ASSERT(nek0 == D); - GGML_ASSERT(nev1 == D); - - GGML_ASSERT(neq1 == N); - GGML_ASSERT(nek1 == N + P); - GGML_ASSERT(nev1 == D); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - const float scale = 1.0f/sqrtf(D); - - //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); - - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - float * S = (float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32); - - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } - - if (GGML_VEC_DOT_UNROLL > 2 || nek1 % GGML_VEC_DOT_UNROLL != 0) { - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16(neq0, - S + i1, 0, - (ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1); - } - } else { - for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2 % nek2; - const int ik1 = ic; - - // S indices - const int i1 = ik1; - - ggml_vec_dot_f16_unroll(neq0, nbk1, - S + i1, - ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } - } - - // scale - ggml_vec_scale_f32(nek1, S, scale); - - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } - } - - // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. - // dont forget to set their S values to zero - { - float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); - - ggml_float sum = 0.0; - { -#ifdef GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(S, 1, &max, S, 1, Mup); - vvexpf(S, S, &Mup); - ggml_vec_sum_f32(Mup, &sum, S); -#else - sum = ggml_vec_soft_max_f32(Mup, S, S, max); -#endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); - -#ifndef NDEBUG - for (int i = 0; i < M; ++i) { - assert(!isnan(S[i])); - assert(!isinf(S[i])); - } -#endif - } - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). - if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) { - for (int64_t ic = 0; ic < nev1; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f16(nev0, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0, - (ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0, - S16, 0, 1); - } - } else { - for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // v indices - const int iv2 = iq2 % nev2; - const int iv3 = iq3; - - ggml_vec_dot_f16_unroll(nev0, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), - S16); - } - } - } -} - -static void ggml_compute_forward_flash_attn( - const struct ggml_compute_params * params, - const bool masked, - struct ggml_tensor * dst) { - - const struct ggml_tensor * q = dst->src[0]; - - switch (q->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_attn_f16(params, masked, dst); - } break; - case GGML_TYPE_F32: - { - ggml_compute_forward_flash_attn_f32(params, masked, dst); - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_flash_attn_ext static void ggml_compute_forward_flash_attn_ext_f16( @@ -15908,9 +15567,10 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(ne0 == D); GGML_ASSERT(ne2 == N); - GGML_ASSERT(nbq0 == sizeof(float)); - GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); GGML_ASSERT(neq0 == D); GGML_ASSERT(nek0 == D); @@ -15964,6 +15624,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + enum ggml_type const k_vec_dot_type = type_traits[k->type].vec_dot_type; + ggml_from_float_t const q_to_vec_dot = type_traits[k_vec_dot_type].from_float; + ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot; + ggml_to_float_t const v_to_float = type_traits[v->type].to_float; + // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices @@ -15971,17 +15636,22 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iq2 = (ir - iq3*neq2*neq1)/neq1; const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - const uint32_t h = iq2; // head + const uint32_t h = iq2; // head index const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - float S = 0.0f; - float M = -INFINITY; + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value - float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); - ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory - ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); + float * VKQ32 = (float *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator + float * V32 = (VKQ32 + 1*D); // (temporary) FP32 V buffer + ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator + ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16 - memset(V16, 0, D*sizeof(ggml_fp16_t)); + if (v->type == GGML_TYPE_F16) { + memset(VKQ16, 0, D*sizeof(ggml_fp16_t)); + } else { + memset(VKQ32, 0, D*sizeof(float)); + } const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; @@ -15993,6 +15663,9 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + q_to_vec_dot(pq, Q_q, D); + // online softmax / attention // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf @@ -16002,51 +15675,66 @@ static void ggml_compute_forward_flash_attn_ext_f16( continue; } - float s; + float s; // KQ value - // convert Q to F16 in V32 - { - const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); + kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1); - for (int64_t d = 0; d < D; ++d) { - Q16[d] = GGML_FP32_TO_FP16(pq[d]); - } - } - - ggml_vec_dot_f16(D, - &s, 0, - (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, - Q16, 0, 1); - - s = s*scale + mv; + s = s*scale + mv; // scale KQ value and apply mask const float Mold = M; - float ms = 1.0f; - float vs = 1.0f; + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) - if (s > M) { - M = s; - ms = expf(Mold - M); + const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - // V = V*expf(Mold - M) - ggml_vec_scale_f16(D, V16, ms); + if (v->type== GGML_TYPE_F16) { + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, VKQ16, ms); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs); } else { - vs = expf(s - M); + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f32(D, VKQ32, ms); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + v_to_float(v_data, V32, D); + + // V += v*expf(s - M) + ggml_vec_mad_f32(D, VKQ32, V32, vs); } - const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + S = S*ms + vs; // scale and increment sum with partial sum + } - // V += v*expf(s - M) - ggml_vec_mad_f16(D, V16, v16, vs); - - S = S*ms + vs; + if (v->type == GGML_TYPE_F16) { + for (int64_t d = 0; d < D; ++d) { + VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); + } } // V /= S - for (int64_t d = 0; d < D; ++d) { - V32[d] = GGML_FP16_TO_FP32(V16[d])/S; - } + const float S_inv = 1.0f/S; + ggml_vec_scale_f32(D, VKQ32, S_inv); // dst indices const int i1 = iq1; @@ -16057,7 +15745,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); } } @@ -16082,165 +15770,6 @@ static void ggml_compute_forward_flash_attn_ext( } } -// ggml_compute_forward_flash_ff - -static void ggml_compute_forward_flash_ff_f16( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * a = dst->src[0]; // F16 - const struct ggml_tensor * b0 = dst->src[1]; // F16 fc_w - const struct ggml_tensor * b1 = dst->src[2]; // F32 fc_b - const struct ggml_tensor * c0 = dst->src[3]; // F16 proj_w - const struct ggml_tensor * c1 = dst->src[4]; // F32 proj_b - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_LOCALS(int64_t, nea, a, ne) - GGML_TENSOR_LOCALS(size_t, nba, a, nb) - GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) - GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) - GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) - GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb) - GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) - GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) - GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) - GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t D = nea0; - //const int64_t N = nea1; - const int64_t M = neb01; - - GGML_ASSERT(ne0 == nea0); - GGML_ASSERT(ne1 == nea1); - GGML_ASSERT(ne2 == nea2); - - GGML_ASSERT(nba0 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbb10 == sizeof(float)); - GGML_ASSERT(nbc00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nbc10 == sizeof(float)); - - GGML_ASSERT(neb00 == D); - GGML_ASSERT(neb01 == M); - GGML_ASSERT(neb10 == M); - GGML_ASSERT(neb11 == 1); - - GGML_ASSERT(nec00 == M); - GGML_ASSERT(nec01 == D); - GGML_ASSERT(nec10 == D); - GGML_ASSERT(nec11 == 1); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - if (params->type == GGML_TASK_TYPE_INIT) { - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // parallelize by a rows using ggml_vec_dot_f32 - - // total rows in a - const int nr = nea1*nea2*nea3; - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - for (int ir = ir0; ir < ir1; ++ir) { - // a indices - const int ia3 = ir/(nea2*nea1); - const int ia2 = (ir - ia3*nea2*nea1)/nea1; - const int ia1 = (ir - ia3*nea2*nea1 - ia2*nea1); - - float * S = (float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32); - - for (int64_t ic = 0; ic < neb01; ++ic) { - // b0 indices - const int ib03 = ia3; - const int ib02 = ia2; - const int ib01 = ic; - - // S indices - const int i1 = ib01; - - ggml_vec_dot_f16(nea0, - S + i1, 0, - (ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0, - (ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1); - } - - ggml_vec_add_f32(neb01, S, S, (float *) b1->data); - //ggml_vec_gelu_f32(neb01, S, S); - - ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M); - - for (int64_t i = 0; i < M; i++) { - S16[i] = GGML_FP32_TO_FP16(S[i]); - } - - ggml_vec_gelu_f16(neb01, S16, S16); - - { - // dst indices - const int i1 = ia1; - const int i2 = ia2; - const int i3 = ia3; - - for (int64_t ic = 0; ic < nec01; ++ic) { - - ggml_vec_dot_f16(neb01, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0, - (ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0, - S16, 0, 1); - } - - ggml_vec_add_f32(nec01, - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)), - (float *) c1->data); - } - } -} - -static void ggml_compute_forward_flash_ff( - const struct ggml_compute_params * params, - struct ggml_tensor * dst) { - - const struct ggml_tensor * b0 = dst->src[1]; - - switch (b0->type) { - case GGML_TYPE_F16: - { - ggml_compute_forward_flash_ff_f16(params, dst); - } break; - case GGML_TYPE_F32: - { - GGML_ASSERT(false); // TODO - } break; - default: - { - GGML_ASSERT(false); - } break; - } -} - // ggml_compute_forward_flash_attn_back static void ggml_compute_forward_flash_attn_back_f32( @@ -17811,21 +17340,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_leaky_relu(params, tensor); } break; - case GGML_OP_FLASH_ATTN: - { - const int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - const bool masked = t != 0; - ggml_compute_forward_flash_attn(params, masked, tensor); - } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); } break; - case GGML_OP_FLASH_FF: - { - ggml_compute_forward_flash_ff(params, tensor); - } break; case GGML_OP_FLASH_ATTN_BACK: { int32_t t = ggml_get_op_params_i32(tensor, 0); @@ -18195,6 +17713,7 @@ static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct gg static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set zero_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; + struct ggml_tensor * src2 = tensor->src[2]; switch (tensor->op) { case GGML_OP_DUP: @@ -18708,9 +18227,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18718,26 +18237,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rope_back(ctx, tensor->grad, src1, + src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, - beta_slow, - xpos_base, - xpos_down), + beta_slow), zero_table); } } break; @@ -18747,9 +18262,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; - const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_orig_ctx = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base, xpos_down; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); @@ -18757,26 +18272,22 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); - memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float)); - memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool)); src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rope_impl(ctx, tensor->grad, src1, + src2, n_dims, mode, - n_ctx, - n_orig_ctx, + n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, - xpos_base, - xpos_down, false), zero_table); } @@ -18829,7 +18340,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; - case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; @@ -18846,7 +18356,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor masked); } - struct ggml_tensor * src2 = tensor->src[2]; const int64_t elem_q = ggml_nelements(src0); const int64_t elem_k = ggml_nelements(src1); const int64_t elem_v = ggml_nelements(src2); @@ -18884,10 +18393,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; - case GGML_OP_FLASH_FF: - { - GGML_ASSERT(false); // not supported - } break; case GGML_OP_FLASH_ATTN_BACK: { GGML_ASSERT(false); // not supported @@ -19574,15 +19079,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ { n_tasks = n_threads; } break; - case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; - case GGML_OP_FLASH_FF: - { - n_tasks = n_threads; - } break; case GGML_OP_FLASH_ATTN_BACK: { n_tasks = n_threads; @@ -19894,11 +19394,6 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa { const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; -#if defined(GGML_USE_CLBLAST) - if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { - cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); - } else -#endif #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node)) { if (node->src[0]->type != GGML_TYPE_F32) { @@ -19979,39 +19474,11 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; } break; - case GGML_OP_FLASH_ATTN: - { - const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL); - - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_BF16) { - cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } - } break; case GGML_OP_FLASH_ATTN_EXT: { const int64_t ne00 = node->src[0]->ne[0]; // D - cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size - } break; - case GGML_OP_FLASH_FF: - { - if (node->src[1]->type == GGML_TYPE_F32) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_F16) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == GGML_TYPE_BF16) { - cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } + cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -20056,6 +19523,59 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa return cplan; } +static enum ggml_status ggml_graph_compute_parallel(struct ggml_compute_state * workers, int n_threads) { + enum ggml_status compute_status = GGML_STATUS_SUCCESS; + +#ifdef GGML_USE_OPENMP + if (n_threads > 1) { + #pragma omp parallel num_threads(n_threads) + { + #pragma omp single + { + // update the number of threads from the actual number of threads that we got from OpenMP + n_threads = omp_get_num_threads(); + workers[0].shared->n_threads = n_threads; + workers[0].shared->n_active = n_threads; + } + ggml_graph_compute_thread(&workers[omp_get_thread_num()]); + } + } else { + ggml_graph_compute_thread(&workers[0]); + } +#else + // create thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; ++j) { + const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); + GGML_ASSERT(rc == 0); + UNUSED(rc); + } + } + + // this is a work thread too + ggml_graph_compute_thread(&workers[0]); + + // join or kill thread pool + if (n_threads > 1) { + for (int j = 1; j < n_threads; j++) { + const int rc = ggml_thread_join(workers[j].thrd, NULL); + GGML_ASSERT(rc == 0); + UNUSED(rc); + } + } +#endif + // don't leave affinity set on the main thread + clear_numa_thread_affinity(); + + for (int j = 0; j < n_threads; j++) { + if (workers[j].ec != GGML_STATUS_SUCCESS) { + compute_status = workers[j].ec; + break; + } + } + return compute_status; +} + enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { { GGML_ASSERT(cplan); @@ -20066,7 +19586,11 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl } } - const int n_threads = cplan->n_threads; + int n_threads = cplan->n_threads; + +#if defined(GGML_USE_OPENMP) + n_threads = MIN(n_threads, omp_get_max_threads()); +#endif struct ggml_compute_state_shared state_shared = { /*.cgraph =*/ cgraph, @@ -20082,47 +19606,20 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl /*.current_chunk; =*/ 0, }; struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); - - // create thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; ++j) { - workers[j] = (struct ggml_compute_state) { - .thrd = 0, - .ith = j, - .shared = &state_shared, - .ec = GGML_STATUS_SUCCESS, - }; - - const int rc = ggml_thread_create(&workers[j].thrd, NULL, ggml_graph_compute_thread, &workers[j]); - GGML_ASSERT(rc == 0); - UNUSED(rc); - } - } - - workers[0].ith = 0; - workers[0].shared = &state_shared; - workers[0].ec = GGML_STATUS_SUCCESS; - const int64_t perf_start_cycles = ggml_perf_cycles(); const int64_t perf_start_time_us = ggml_perf_time_us(); - // this is a work thread too - ggml_graph_compute_thread(&workers[0]); - enum ggml_status compute_status = workers[0].ec; - - // don't leave affinity set on the main thread - clear_numa_thread_affinity(); - - // join or kill thread pool - if (n_threads > 1) { - for (int j = 1; j < n_threads; j++) { - const int rc = ggml_thread_join(workers[j].thrd, NULL); - GGML_ASSERT(rc == 0); - if (workers[j].ec != GGML_STATUS_SUCCESS) - compute_status = workers[j].ec; - } + for (int j = 0; j < n_threads; ++j) { + workers[j] = (struct ggml_compute_state) { + .thrd = 0, + .ith = j, + .shared = &state_shared, + .ec = GGML_STATUS_SUCCESS, + }; } + enum ggml_status compute_status = ggml_graph_compute_parallel(workers, n_threads); + // performance stats (graph) { int64_t perf_cycles_cur = ggml_perf_cycles() - perf_start_cycles; @@ -21853,11 +21350,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_S: result = quantize_iq1_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; -#if QK_K == 64 - case GGML_TYPE_IQ4_XS: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; -#else case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; -#endif case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -23134,6 +22627,14 @@ int ggml_cpu_has_avx512_vnni(void) { #endif } +int ggml_cpu_has_avx512_bf16(void) { +#if defined(__AVX512BF16__) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_fma(void) { #if defined(__FMA__) return 1; @@ -23150,6 +22651,16 @@ int ggml_cpu_has_neon(void) { #endif } +int ggml_cpu_has_sve(void) { +#if defined(__ARM_FEATURE_SVE) + // TODO: Currently, SVE 256 bit is only supported. + GGML_ASSERT(svcntb() == QK8_0); + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_arm_fma(void) { #if defined(__ARM_FEATURE_FMA) return 1; @@ -23191,7 +22702,7 @@ int ggml_cpu_has_wasm_simd(void) { } int ggml_cpu_has_blas(void) { -#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_SYCL) +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) return 1; #else return 0; @@ -23206,14 +22717,6 @@ int ggml_cpu_has_cuda(void) { #endif } -int ggml_cpu_has_clblast(void) { -#if defined(GGML_USE_CLBLAST) - return 1; -#else - return 0; -#endif -} - int ggml_cpu_has_vulkan(void) { #if defined(GGML_USE_VULKAN) return 1; @@ -23238,9 +22741,16 @@ int ggml_cpu_has_sycl(void) { #endif } +int ggml_cpu_has_rpc(void) { +#if defined(GGML_USE_RPC) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_gpublas(void) { - return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || - ggml_cpu_has_sycl(); + return ggml_cpu_has_cuda() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() || ggml_cpu_has_sycl(); } int ggml_cpu_has_sse3(void) { diff --git a/llama/ggml.h b/llama/ggml.h index b45ceb7f..00ef4582 100644 --- a/llama/ggml.h +++ b/llama/ggml.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -507,9 +507,7 @@ extern "C" { GGML_OP_ARGSORT, GGML_OP_LEAKY_RELU, - GGML_OP_FLASH_ATTN, GGML_OP_FLASH_ATTN_EXT, - GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, GGML_OP_SSM_SCAN, @@ -784,7 +782,6 @@ extern "C" { GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor); - GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor); GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor); GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); @@ -793,6 +790,11 @@ extern "C" { GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars + GGML_API GGML_CALL bool ggml_is_contiguous (const struct ggml_tensor * tensor); + GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() + GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 + GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); @@ -1035,12 +1037,13 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // concat a and b on dim 2 + // concat a and b along dim // used in stable-diffusion GGML_API struct ggml_tensor * ggml_concat( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b, + int dim); GGML_API struct ggml_tensor * ggml_abs( struct ggml_context * ctx, @@ -1486,18 +1489,17 @@ extern "C" { struct ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements (DEPRECATED) + // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED) // if mode & 2 == 1, GPT-NeoX style - // if mode & 4 == 1, ChatGLM style // // b is an int32 vector with size a->ne[2], it contains the positions + // c is freq factors (e.g. phi3-128k), (optional) GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx); + int mode); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_inplace( @@ -1505,18 +1507,17 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - int mode, - int n_ctx); + int mode); // custom RoPE - GGML_API struct ggml_tensor * ggml_rope_custom( + GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1525,14 +1526,14 @@ extern "C" { float beta_slow); // in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + GGML_API struct ggml_tensor * ggml_rope_ext_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, @@ -1540,18 +1541,39 @@ extern "C" { float beta_fast, float beta_slow); - // compute correction dims for YaRN RoPE scaling - GGML_CALL void ggml_rope_yarn_corr_dims( - int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); - - // xPos RoPE, in-place, returns view(a) - GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, int n_dims, - float base, - bool down); + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow), + "use ggml_rope_ext_inplace instead"); + + // compute correction dims for YaRN RoPE scaling + GGML_CALL void ggml_rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]); // rotary position embedding backward, i.e compute dx from dy // a - dy @@ -1559,18 +1581,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, + struct ggml_tensor * c, int n_dims, int mode, - int n_ctx, - int n_orig_ctx, + int n_ctx_orig, float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, - float beta_slow, - float xpos_base, - bool xpos_down); + float beta_slow); // clamp // in-place, returns view(a) @@ -1760,13 +1780,6 @@ extern "C" { struct ggml_tensor * a, int k); - GGML_API struct ggml_tensor * ggml_flash_attn( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - bool masked); - #define GGML_KQ_MASK_PAD 32 // q: [n_embd, n_batch, n_head, 1] @@ -1787,6 +1800,7 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + // TODO: needs to be adapted to ggml_flash_attn_ext GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, @@ -1795,14 +1809,6 @@ extern "C" { struct ggml_tensor * d, bool masked); - GGML_API struct ggml_tensor * ggml_flash_ff( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b0, - struct ggml_tensor * b1, - struct ggml_tensor * c0, - struct ggml_tensor * c1); - GGML_API struct ggml_tensor * ggml_ssm_conv( struct ggml_context * ctx, struct ggml_tensor * s, @@ -2416,8 +2422,10 @@ extern "C" { GGML_API int ggml_cpu_has_avx512 (void); GGML_API int ggml_cpu_has_avx512_vbmi(void); GGML_API int ggml_cpu_has_avx512_vnni(void); + GGML_API int ggml_cpu_has_avx512_bf16(void); GGML_API int ggml_cpu_has_fma (void); GGML_API int ggml_cpu_has_neon (void); + GGML_API int ggml_cpu_has_sve (void); GGML_API int ggml_cpu_has_arm_fma (void); GGML_API int ggml_cpu_has_metal (void); GGML_API int ggml_cpu_has_f16c (void); @@ -2425,13 +2433,13 @@ extern "C" { GGML_API int ggml_cpu_has_wasm_simd (void); GGML_API int ggml_cpu_has_blas (void); GGML_API int ggml_cpu_has_cuda (void); - GGML_API int ggml_cpu_has_clblast (void); GGML_API int ggml_cpu_has_vulkan (void); GGML_API int ggml_cpu_has_kompute (void); GGML_API int ggml_cpu_has_gpublas (void); GGML_API int ggml_cpu_has_sse3 (void); GGML_API int ggml_cpu_has_ssse3 (void); GGML_API int ggml_cpu_has_sycl (void); + GGML_API int ggml_cpu_has_rpc (void); GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_matmul_int8(void); diff --git a/llama/grammar-parser.cpp b/llama/grammar-parser.cpp index 336b0d59..6e69d073 100644 --- a/llama/grammar-parser.cpp +++ b/llama/grammar-parser.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -72,8 +72,12 @@ namespace grammar_parser { state.rules[rule_id] = rule; } + static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; + } + static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); } static std::pair parse_hex(const char * src, int size) { @@ -125,6 +129,17 @@ namespace grammar_parser { return pos; } + static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; + } + static std::pair parse_char(const char * src) { if (*src == '\\') { switch (src[1]) { @@ -163,6 +178,60 @@ namespace grammar_parser { bool is_nested) { size_t last_sym_start = out_elements.size(); const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); + if (min_times == 0) { + out_elements.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + std::vector rec_rule(previous_elements); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(previous_elements.size()); + uint32_t rec_rule_id = generate_symbol_id(state, rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + while (*pos) { if (*pos == '"') { // literal string pos++; @@ -223,40 +292,51 @@ namespace grammar_parser { throw std::runtime_error(std::string("expecting ')' at ") + pos); } pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // rewrite rules: - // S* --> S' ::= S S' | - // S+ --> S' ::= S S' | S - // S? --> S' ::= S | - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - std::vector sub_rule; - // add preceding symbol to generated rule - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - if (*pos == '*' || *pos == '+') { - // cause generated rule to recurse - sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - } - // mark start of alternate def - sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - if (*pos == '+') { - // add preceding symbol as alternate only for '+' (otherwise empty) - sub_rule.insert( - sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); - } - sub_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, sub_rule_id, sub_rule); - - // in original rule, replace previous symbol with reference to generated rule - out_elements.resize(last_sym_start); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - + } else if (*pos == '.') { // any char + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); } else { break; } @@ -351,6 +431,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: return true; case LLAMA_GRETYPE_CHAR_ALT: return true; case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; default: return false; } } @@ -365,6 +446,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } switch (elem.type) { case LLAMA_GRETYPE_END: @@ -376,6 +458,7 @@ namespace grammar_parser { case LLAMA_GRETYPE_CHAR_NOT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "(\""); print_grammar_char(file, elem.value); fprintf(file, "\") "); @@ -433,11 +516,15 @@ namespace grammar_parser { } print_grammar_char(file, elem.value); break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; } if (is_char_element(elem)) { switch (rule[i + 1].type) { case LLAMA_GRETYPE_CHAR_ALT: case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: break; default: fprintf(file, "] "); diff --git a/llama/grammar-parser.h b/llama/grammar-parser.h index ed918160..bb25744d 100644 --- a/llama/grammar-parser.h +++ b/llama/grammar-parser.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/json-schema-to-grammar.cpp b/llama/json-schema-to-grammar.cpp index ed876779..40212c0d 100644 --- a/llama/json-schema-to-grammar.cpp +++ b/llama/json-schema-to-grammar.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -42,58 +42,27 @@ static std::string join(Iterator begin, Iterator end, const std::string & separa static std::string repeat(const std::string & str, size_t n); -static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "", bool item_rule_is_literal = false) { +static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { + auto has_max = max_items != std::numeric_limits::max(); + + if (min_items == 0 && max_items == 1) { + return item_rule + "?"; + } + if (separator_rule.empty()) { - if (min_items == 0 && max_items == 1) { - return item_rule + "?"; - } else if (min_items == 1 && max_items == std::numeric_limits::max()) { + if (min_items == 1 && !has_max) { return item_rule + "+"; - } - } - - std::string result; - if (min_items > 0) { - if (item_rule_is_literal && separator_rule.empty()) { - result = "\"" + repeat(std::string(item_rule.begin() + 1, item_rule.end() - 1), min_items) + "\""; + } else if (min_items == 0 && !has_max) { + return item_rule + "*"; } else { - std::vector items(min_items, item_rule); - result = join(items.begin(), items.end(), separator_rule.empty() ? " " : " " + separator_rule + " "); + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } } - std::function opt_repetitions = [&](int up_to_n, bool prefix_with_sep) -> std::string { - auto content = prefix_with_sep && !separator_rule.empty() ? separator_rule + " " + item_rule : item_rule; - - if (up_to_n == 0) { - return ""; - } else if (up_to_n == 1) { - return "(" + content + ")?"; - } else if (!separator_rule.empty() && !prefix_with_sep) { - return "(" + content + " " + opt_repetitions(up_to_n - 1, true) + ")?"; - } else { - std::string res = repeat("(" + content + " ", up_to_n); - // strip trailing space - res = res.substr(0, res.length() - 1); - res += repeat(")?", up_to_n); - return res; - } - }; - - if (min_items > 0 && max_items != min_items) { - result += " "; + auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); + if (min_items == 0) { + result = "(" + result + ")?"; } - - if (max_items != std::numeric_limits::max()) { - result += opt_repetitions(max_items - min_items, min_items > 0); - } else { - std::string item_operator = "(" + (separator_rule.empty() ? "" : separator_rule + " ") + item_rule + ")"; - if (min_items == 0 && !separator_rule.empty()) { - result = "(" + item_rule + " " + item_operator + "*)?"; - } else { - result += item_operator + "*"; - } - } - return result; } @@ -104,30 +73,24 @@ struct BuiltinRule { std::vector deps; }; -const std::string _up_to_15_digits = build_repetition("[0-9]", 0, 15); - std::unordered_map PRIMITIVE_RULES = { {"boolean", {"(\"true\" | \"false\") space", {}}}, - {"decimal-part", {"[0-9] " + _up_to_15_digits, {}}}, - {"integral-part", {"[0-9] | [1-9] " + _up_to_15_digits, {}}}, + {"decimal-part", {"[0-9]{1,16}", {}}}, + {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, - {"uuid", {"\"\\\"\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] " - "\"-\" [0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F] \"\\\"\" space", {}}}, - {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])", {}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"char", {"[^\"\\\\] | \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, {"null", {"\"null\" space", {}}}, }; std::unordered_map STRING_FORMAT_RULES = { - {"date", {"[0-9] [0-9] [0-9] [0-9] \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, - {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9] [0-9] [0-9] )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, + {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, + {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, @@ -411,8 +374,7 @@ private: sub_is_literal ? "\"" + sub + "\"" : sub, min_times, max_times, - "", - sub_is_literal + "" ); seq.back().second = false; } else { diff --git a/llama/json-schema-to-grammar.h b/llama/json-schema-to-grammar.h index 36ca9652..9761e6ce 100644 --- a/llama/json-schema-to-grammar.h +++ b/llama/json-schema-to-grammar.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/llama.cpp b/llama/llama.cpp index 48cb19d5..5d99ab1a 100644 --- a/llama/llama.cpp +++ b/llama/llama.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -39,8 +39,6 @@ #ifdef GGML_USE_CUDA # include "ggml-cuda.h" -#elif defined(GGML_USE_CLBLAST) -# include "ggml-opencl.h" #elif defined(GGML_USE_VULKAN) # include "ggml-vulkan.h" #elif defined(GGML_USE_SYCL) @@ -52,16 +50,9 @@ #ifdef GGML_USE_METAL # include "ggml-metal.h" #endif -#ifdef GGML_USE_MPI -# include "ggml-mpi.h" -#endif -#ifndef QK_K -# ifdef GGML_QKK_64 -# define QK_K 64 -# else -# define QK_K 256 -# endif -#endif + +// TODO: replace with ggml API call +#define QK_K 256 #ifdef __has_include #if __has_include() @@ -136,14 +127,14 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 60 +#define LLAMA_MAX_EXPERTS 160 // // logging // LLAMA_ATTRIBUTE_FORMAT(2, 3) -static void llama_log_internal (ggml_log_level level, const char* format, ...); +static void llama_log_internal (ggml_log_level level, const char * format, ...); static void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) @@ -231,7 +222,6 @@ enum llm_arch { LLM_ARCH_GPTNEOX, LLM_ARCH_MPT, LLM_ARCH_STARCODER, - LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, LLM_ARCH_BERT, LLM_ARCH_NOMIC_BERT, @@ -255,6 +245,8 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK2, LLM_ARCH_UNKNOWN, }; @@ -268,7 +260,6 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MPT, "mpt" }, { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_PERSIMMON, "persimmon" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, @@ -292,6 +283,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -312,11 +305,15 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, + LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, + LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, + LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, @@ -329,14 +326,18 @@ enum llm_kv { LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, LLM_KV_ATTENTION_CAUSAL, + LLM_KV_ATTENTION_Q_LORA_RANK, + LLM_KV_ATTENTION_KV_LORA_RANK, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, + LLM_KV_ROPE_SCALING_YARN_LOG_MUL, LLM_KV_SPLIT_NO, LLM_KV_SPLIT_COUNT, @@ -385,17 +386,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, - { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, - { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, - { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, - { LLM_KV_BLOCK_COUNT, "%s.block_count" }, - { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, - { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, - { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, - { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, - { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, - { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, - { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, + { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, + { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_BLOCK_COUNT, "%s.block_count" }, + { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, + { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, + { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, + { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, + { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, + { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, + { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, + { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, + { LLM_KV_POOLING_TYPE , "%s.pooling_type" }, + { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -406,14 +411,18 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -467,6 +476,8 @@ enum llm_tensor { LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_K, LLM_TENSOR_ATTN_V, @@ -486,6 +497,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility LLM_TENSOR_FFN_GATE_EXP, LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, @@ -502,6 +514,12 @@ enum llm_tensor { LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, }; static const std::map> LLM_TENSOR_NAMES = { @@ -624,23 +642,6 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, - { - LLM_ARCH_PERSIMMON, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd"}, - { LLM_TENSOR_OUTPUT_NORM, "output_norm"}, - { LLM_TENSOR_OUTPUT, "output"}, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"}, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv"}, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"}, - { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm"}, - { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm"}, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"}, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"}, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"}, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd"}, - }, - }, { LLM_ARCH_MPT, { @@ -729,6 +730,7 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -851,18 +853,20 @@ static const std::map> LLM_TENSOR_NA { LLM_ARCH_PHI3, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -1078,6 +1082,57 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_ARCTIC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_DEEPSEEK2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, + { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, + { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, + { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, + { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1672,12 +1727,13 @@ struct llama_mlock { }; using llama_mlocks = std::vector>; -static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { +// NOTE: avoid ever using this except for building the token_to_piece caches +static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { std::vector result(8, 0); - const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); + const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special); + int check = llama_token_to_piece(model, token, result.data(), result.size(), special); GGML_ASSERT(check == -n_tokens); } else { @@ -1723,6 +1779,8 @@ struct llama_state { llama_state() { #ifdef GGML_USE_METAL ggml_backend_metal_log_set_callback(log_callback, log_callback_user_data); +#elif defined(GGML_USE_CUDA) + ggml_backend_cuda_log_set_callback(log_callback, log_callback_user_data); #endif } @@ -1736,23 +1794,31 @@ static llama_state g_state; // available llama models enum e_model { MODEL_UNKNOWN, + MODEL_14M, MODEL_17M, MODEL_22M, MODEL_33M, + MODEL_70M, MODEL_109M, MODEL_137M, + MODEL_160M, MODEL_335M, + MODEL_410M, MODEL_0_5B, MODEL_1B, + MODEL_1_4B, MODEL_2B, + MODEL_2_8B, MODEL_3B, MODEL_4B, + MODEL_6_9B, MODEL_7B, MODEL_8B, MODEL_12B, MODEL_13B, MODEL_14B, MODEL_15B, + MODEL_16B, MODEL_20B, MODEL_30B, MODEL_34B, @@ -1760,6 +1826,7 @@ enum e_model { MODEL_40B, MODEL_65B, MODEL_70B, + MODEL_236B, MODEL_314B, MODEL_SMALL, MODEL_MEDIUM, @@ -1769,6 +1836,7 @@ enum e_model { MODEL_8x7B, MODEL_8x22B, MODEL_16x12B, + MODEL_10B_128x3_66B, }; static const size_t kiB = 1024; @@ -1778,6 +1846,7 @@ static const size_t GiB = 1024*MiB; struct llama_hparams { bool vocab_only; bool rope_finetuned; + bool use_par_res; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on @@ -1793,12 +1862,21 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + uint32_t n_layer_dense_lead = 0; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_expert_shared = 0; + float expert_weights_scale = 0.0; + float f_norm_eps; float f_norm_rms_eps; + float rope_attn_factor = 1.0f; float rope_freq_base_train; float rope_freq_scale_train; - uint32_t n_yarn_orig_ctx; + uint32_t n_ctx_orig_yarn; + float rope_yarn_log_mul; // for State Space Models uint32_t ssm_d_conv = 0; @@ -1832,8 +1910,14 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; + if (this->n_lora_q != other.n_lora_q) return true; + if (this->n_lora_kv != other.n_lora_kv) return true; + if (this->n_ff_exp != other.n_ff_exp) return true; + if (this->n_expert_shared != other.n_expert_shared) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; - if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; + if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true; if (this->ssm_d_conv != other.ssm_d_conv) return true; if (this->ssm_d_inner != other.ssm_d_inner) return true; @@ -1844,8 +1928,11 @@ struct llama_hparams { if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; + if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true; if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; + if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; + if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; return false; } @@ -1889,7 +1976,7 @@ struct llama_cparams { float rope_freq_base; float rope_freq_scale; - uint32_t n_yarn_orig_ctx; + uint32_t n_ctx_orig_yarn; // These hyperparameters are not exposed in GGUF, because all // existing YaRN models use the same values for them. float yarn_ext_factor; @@ -1921,6 +2008,8 @@ struct llama_layer { struct ggml_tensor * attn_k_norm_b; struct ggml_tensor * attn_out_norm; struct ggml_tensor * attn_out_norm_b; + struct ggml_tensor * attn_q_a_norm; + struct ggml_tensor * attn_kv_a_norm; // attention struct ggml_tensor * wq; @@ -1928,6 +2017,10 @@ struct llama_layer { struct ggml_tensor * wv; struct ggml_tensor * wo; struct ggml_tensor * wqkv; + struct ggml_tensor * wq_a; + struct ggml_tensor * wq_b; + struct ggml_tensor * wkv_a_mqa; + struct ggml_tensor * wkv_b; // attention bias struct ggml_tensor * bq; @@ -1941,6 +2034,7 @@ struct llama_layer { struct ggml_tensor * ffn_norm_b; struct ggml_tensor * layer_out_norm; struct ggml_tensor * layer_out_norm_b; + struct ggml_tensor * ffn_norm_exps; // ff struct ggml_tensor * ffn_gate; // w1 @@ -1960,8 +2054,9 @@ struct llama_layer { struct ggml_tensor * ffn_up_shexp; // ff bias - struct ggml_tensor * ffn_down_b; // b2 - struct ggml_tensor * ffn_up_b; // b3 + struct ggml_tensor * ffn_gate_b = nullptr; + struct ggml_tensor * ffn_down_b = nullptr; // b2 + struct ggml_tensor * ffn_up_b = nullptr; // b3 struct ggml_tensor * ffn_act; // mamba proj @@ -1978,6 +2073,10 @@ struct llama_layer { // mamba bias struct ggml_tensor * ssm_conv1d_b; struct ggml_tensor * ssm_dt_b; + + // long rope factors + struct ggml_tensor * rope_long = nullptr; + struct ggml_tensor * rope_short = nullptr; }; struct llama_kv_cell { @@ -2075,12 +2174,12 @@ struct llama_control_vector { struct llama_vocab { using id = int32_t; using token = std::string; - using ttype = llama_token_type; + using tattr = llama_token_attr; struct token_data { token text; float score; - ttype type; + tattr attr; }; enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; @@ -2089,7 +2188,8 @@ struct llama_vocab { std::unordered_map token_to_id; std::vector id_to_token; - std::unordered_map special_tokens_cache; + std::vector cache_special_tokens; + std::vector cache_token_to_piece; // llama_token_to_piece(special = true); std::map, int> bpe_ranks; @@ -2294,19 +2394,36 @@ struct llama_context { // control vectors struct llama_control_vector cvec; - -#ifdef GGML_USE_MPI - ggml_mpi_context * ctx_mpi = NULL; -#endif }; +static size_t llama_get_device_count(const llama_model & model) { + size_t count = 1; +#if defined(GGML_USE_CUDA) + count = ggml_backend_cuda_get_device_count(); +#elif defined(GGML_USE_SYCL) + count = ggml_backend_sycl_get_device_count(); +#elif defined(GGML_USE_VULKAN) + count = ggml_backend_vk_get_device_count(); +#endif +#if defined(GGML_USE_RPC) + count += model.rpc_servers.size(); +#endif + return count; + GGML_UNUSED(model); +} + static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) { ggml_backend_buffer_type_t buft = nullptr; -#ifdef GGML_USE_RPC - std::string endpoint = model.rpc_servers[gpu]; - buft = ggml_backend_rpc_buffer_type(endpoint.c_str()); -#elif defined(GGML_USE_METAL) +#if defined(GGML_USE_RPC) + int dev_count = (int)llama_get_device_count(model); + int rpc_count = (int)model.rpc_servers.size(); + if (gpu >= dev_count - rpc_count) { + const char * endpoint = model.rpc_servers[gpu - dev_count + rpc_count].c_str(); + return ggml_backend_rpc_buffer_type(endpoint); + } +#endif +#if defined(GGML_USE_METAL) buft = ggml_backend_metal_buffer_type(); #elif defined(GGML_USE_CUDA) buft = ggml_backend_cuda_buffer_type(gpu); @@ -2314,8 +2431,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_ buft = ggml_backend_vk_buffer_type(gpu); #elif defined(GGML_USE_SYCL) buft = ggml_backend_sycl_buffer_type(gpu); -#elif defined(GGML_USE_CLBLAST) - buft = ggml_backend_opencl_buffer_type(); #elif defined(GGML_USE_KOMPUTE) buft = ggml_backend_kompute_buffer_type(gpu); if (buft == nullptr) { @@ -2354,29 +2469,19 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo GGML_UNUSED(tensor_split); } -static size_t llama_get_device_count(const llama_model & model) { -#if defined(GGML_USE_RPC) - return model.rpc_servers.size(); -#elif defined(GGML_USE_CUDA) - return ggml_backend_cuda_get_device_count(); -#elif defined(GGML_USE_SYCL) - return ggml_backend_sycl_get_device_count(); -#elif defined(GGML_USE_VULKAN) - return ggml_backend_vk_get_device_count(); -#else - return 1; -#endif - GGML_UNUSED(model); -} - static size_t llama_get_device_memory(const llama_model & model, int device) { #if defined(GGML_USE_RPC) - size_t total; - size_t free; - std::string endpoint = model.rpc_servers[device]; - ggml_backend_rpc_get_device_memory(endpoint.c_str(), &free, &total); - return free; -#elif defined(GGML_USE_CUDA) + int dev_count = (int)llama_get_device_count(model); + int rpc_count = (int)model.rpc_servers.size(); + if (device >= dev_count - rpc_count) { + size_t total; + size_t free; + const char * endpoint = model.rpc_servers[device - dev_count + rpc_count].c_str(); + ggml_backend_rpc_get_device_memory(endpoint, &free, &total); + return free; + } +#endif +#if defined(GGML_USE_CUDA) size_t total; size_t free; ggml_backend_cuda_get_device_memory(device, &free, &total); @@ -2448,10 +2553,6 @@ static bool llama_kv_cache_init( } } -#ifdef GGML_USE_CLBLAST - offload = false; -#endif - // count used buffer types std::map buft_layer_count; if (offload) { @@ -2517,7 +2618,6 @@ static bool llama_kv_cache_init( static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_batch & batch) { - const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; if (cache.recurrent) { @@ -2568,16 +2668,16 @@ static bool llama_kv_cache_find_slot( } // otherwise, one cell per token. - if (n_tokens > n_ctx) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + if (n_tokens > cache.size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); return false; } uint32_t n_tested = 0; while (true) { - if (cache.head + n_tokens > n_ctx) { - n_tested += n_ctx - cache.head; + if (cache.head + n_tokens > cache.size) { + n_tested += cache.size - cache.head; cache.head = 0; continue; } @@ -2596,7 +2696,7 @@ static bool llama_kv_cache_find_slot( break; } - if (n_tested >= n_ctx) { + if (n_tested >= cache.size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); return false; } @@ -3356,6 +3456,39 @@ struct llama_model_loader { return get_arr_n(llm_kv(kid), result, required); } + template + bool get_arr(const std::string & key, std::vector & result, const bool required = true) { + const int kid = gguf_find_key(meta, key.c_str()); + + if (kid < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + struct GGUFMeta::ArrayInfo arr_info = + GGUFMeta::GKV::get_kv(meta, kid); + + if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { + throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + } + + // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); + GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); + + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + + return true; + } + + template + bool get_arr(const enum llm_kv kid, T& result, const bool required = true) { + return get_arr(llm_kv(kid), result, required); + } + template bool get_key(const std::string & key, T & result, const bool required = true) { auto it = kv_overrides.find(key); @@ -3430,11 +3563,15 @@ struct llama_model_loader { return get_tensor_meta(get_tensor_name(i)); } - struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur) { + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, const struct ggml_tensor * cur, bool duplicated) { struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); ggml_set_name(tensor, ggml_get_name(cur)); - n_created++; + if (duplicated) { + size_data += ggml_nbytes(cur); + } else { + n_created++; + } return tensor; } @@ -3469,14 +3606,17 @@ struct llama_model_loader { return cur; } - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, bool required = true) { - const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); + static const int TENSOR_NOT_REQUIRED = 1; + static const int TENSOR_DUPLICATED = 2; + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) { + const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { return NULL; } - return create_tensor_for(ctx, cur); + return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED); } struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector & ne, size_t offset, bool required = true) { @@ -3776,37 +3916,50 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { static const char * llama_model_type_name(e_model type) { switch (type) { - case MODEL_22M: return "22M"; - case MODEL_33M: return "33M"; - case MODEL_109M: return "109M"; - case MODEL_137M: return "137M"; - case MODEL_0_5B: return "0.5B"; - case MODEL_1B: return "1B"; - case MODEL_2B: return "2B"; - case MODEL_3B: return "3B"; - case MODEL_7B: return "7B"; - case MODEL_8B: return "8B"; - case MODEL_12B: return "12B"; - case MODEL_13B: return "13B"; - case MODEL_14B: return "14B"; - case MODEL_15B: return "15B"; - case MODEL_20B: return "20B"; - case MODEL_30B: return "30B"; - case MODEL_34B: return "34B"; - case MODEL_35B: return "35B"; - case MODEL_40B: return "40B"; - case MODEL_65B: return "65B"; - case MODEL_70B: return "70B"; - case MODEL_314B: return "314B"; - case MODEL_SMALL: return "0.1B"; - case MODEL_MEDIUM: return "0.4B"; - case MODEL_LARGE: return "0.8B"; - case MODEL_XL: return "1.5B"; - case MODEL_A2_7B: return "A2.7B"; - case MODEL_8x7B: return "8x7B"; - case MODEL_8x22B: return "8x22B"; - case MODEL_16x12B: return "16x12B"; - default: return "?B"; + case MODEL_14M: return "14M"; + case MODEL_17M: return "17M"; + case MODEL_22M: return "22M"; + case MODEL_33M: return "33M"; + case MODEL_70M: return "70M"; + case MODEL_109M: return "109M"; + case MODEL_137M: return "137M"; + case MODEL_160M: return "160M"; + case MODEL_335M: return "335M"; + case MODEL_410M: return "410M"; + case MODEL_0_5B: return "0.5B"; + case MODEL_1B: return "1B"; + case MODEL_1_4B: return "1.4B"; + case MODEL_2B: return "2B"; + case MODEL_2_8B: return "2.8B"; + case MODEL_3B: return "3B"; + case MODEL_4B: return "4B"; + case MODEL_6_9B: return "6.9B"; + case MODEL_7B: return "7B"; + case MODEL_8B: return "8B"; + case MODEL_12B: return "12B"; + case MODEL_13B: return "13B"; + case MODEL_14B: return "14B"; + case MODEL_15B: return "15B"; + case MODEL_16B: return "16B"; + case MODEL_20B: return "20B"; + case MODEL_30B: return "30B"; + case MODEL_34B: return "34B"; + case MODEL_35B: return "35B"; + case MODEL_40B: return "40B"; + case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; + case MODEL_236B: return "236B"; + case MODEL_314B: return "314B"; + case MODEL_SMALL: return "0.1B"; + case MODEL_MEDIUM: return "0.4B"; + case MODEL_LARGE: return "0.8B"; + case MODEL_XL: return "1.5B"; + case MODEL_A2_7B: return "A2.7B"; + case MODEL_8x7B: return "8x7B"; + case MODEL_8x22B: return "8x22B"; + case MODEL_16x12B: return "16x12B"; + case MODEL_10B_128x3_66B: return "10B+128x3.66B"; + default: return "?B"; } } @@ -3879,8 +4032,8 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - hparams.n_yarn_orig_ctx = hparams.n_ctx_train; - ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false); + hparams.n_ctx_orig_yarn = hparams.n_ctx_train; + ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false); // rope_freq_base (optional) hparams.rope_freq_base_train = 10000.0f; @@ -3899,6 +4052,8 @@ static void llm_load_hparams( } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; + ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + // sanity check for n_rot (optional) { hparams.n_rot = (hparams.n_head == 0) ? 0 : hparams.n_embd / hparams.n_head; @@ -3936,7 +4091,9 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 22: model.type = e_model::MODEL_1B; break; case 26: model.type = e_model::MODEL_3B; break; - case 32: model.type = hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B; break; + // granite uses a vocab with len 49152 + case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; + case 36: model.type = e_model::MODEL_8B; break; // granite case 40: model.type = e_model::MODEL_13B; break; case 48: model.type = e_model::MODEL_34B; break; case 60: model.type = e_model::MODEL_30B; break; @@ -3998,14 +4155,6 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; - case LLM_ARCH_PERSIMMON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 36: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; case LLM_ARCH_REFACT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -4147,6 +4296,7 @@ static void llm_load_hparams( switch (hparams.n_layer) { case 24: model.type = e_model::MODEL_1B; break; case 32: model.type = e_model::MODEL_3B; break; + case 40: model.type = e_model::MODEL_14B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4213,6 +4363,8 @@ static void llm_load_hparams( case 30: model.type = e_model::MODEL_3B; break; case 32: model.type = e_model::MODEL_7B; break; case 40: model.type = e_model::MODEL_15B; break; + case 52: model.type = e_model::MODEL_20B; break; // granite + case 88: model.type = e_model::MODEL_34B; break; // granite default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -4287,6 +4439,85 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_GPTNEOX: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff) { + case 512: model.type = e_model::MODEL_14M; break; + case 2048: model.type = e_model::MODEL_70M; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff) { + case 3072: model.type = e_model::MODEL_160M; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff) { + case 8192: model.type = e_model::MODEL_1B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff) { + case 4096: model.type = e_model::MODEL_410M; break; + case 8192: model.type = e_model::MODEL_1_4B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff) { + case 10240: model.type = e_model::MODEL_2_8B; break; + case 16384: model.type = e_model::MODEL_6_9B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff) { + case 20480: model.type = e_model::MODEL_12B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff) { + case 24576: model.type = e_model::MODEL_20B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_ARCTIC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch (hparams.n_layer) { + case 35: model.type = e_model::MODEL_10B_128x3_66B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } else { + model.type = e_model::MODEL_UNKNOWN; + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + switch (hparams.n_layer) { + case 27: model.type = e_model::MODEL_16B; break; + case 60: model.type = e_model::MODEL_236B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -4393,15 +4624,14 @@ static void llm_load_vocab( vocab.special_cls_id = 101; vocab.special_mask_id = 103; vocab.add_space_prefix = false; - } else { - if (tokenizer_model == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; - } else { - LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str()); - LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); - vocab.type = LLAMA_VOCAB_TYPE_SPM; - return; + } else if (tokenizer_model == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + + const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str()); + if (add_space_prefix_keyidx != -1) { + vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx); } + // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); if (merges_keyidx == -1) { @@ -4435,21 +4665,13 @@ static void llm_load_vocab( vocab.special_pad_id = -1; vocab.special_cls_id = -1; vocab.special_mask_id = -1; + } else { + throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } // for now, only BPE models have pre-tokenizers if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { - if (tokenizer_pre.empty()) { - LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); - LLAMA_LOG_WARN("%s: \n", __func__); - LLAMA_LOG_WARN("%s: ************************************ \n", __func__); - LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__); - LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); - LLAMA_LOG_WARN("%s: ************************************ \n", __func__); - LLAMA_LOG_WARN("%s: \n", __func__); - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if ( - tokenizer_pre == "default") { + if (tokenizer_pre == "default") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || @@ -4476,7 +4698,8 @@ static void llm_load_vocab( tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de") { + tokenizer_pre == "jina-v2-de" || + tokenizer_pre == "jina-v2-code") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "refact") { @@ -4487,14 +4710,21 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "qwen2") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "stablelm2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; } else if ( tokenizer_pre == "olmo") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; } else if ( tokenizer_pre == "dbrx") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; + } else if ( + tokenizer_pre == "smaug-bpe") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; } else { - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); + LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -4531,7 +4761,20 @@ static void llm_load_vocab( auto & token_data = vocab.id_to_token[i]; token_data.text = std::move(word); token_data.score = scores ? scores[i] : 0.0f; - token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file + switch(toktypes[i]) { + case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; + case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; + case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; + case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; + case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; + case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; + case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + } + } } GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); @@ -4608,7 +4851,8 @@ static void llm_load_vocab( (t.first == "<|eot_id|>" || t.first == "<|im_end|>" || t.first == "<|end|>" || - t.first == "" + t.first == "" || + t.first == "<|endoftext|>" ) ) { vocab.special_eot_id = t.second; @@ -4620,96 +4864,88 @@ static void llm_load_vocab( // build special tokens cache { - // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, - // and will always be correctly labeled in 'added_tokens.json' etc. - // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed - // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer - // are special tokens. - // From testing, this appears to correlate 1:1 with special tokens. - // - - // Counting special tokens and verifying in only one direction - // is sufficient to detect difference in those two sets. - // - uint32_t special_tokens_count_by_type = 0; - uint32_t special_tokens_count_from_verification = 0; - - bool special_tokens_definition_mismatch = false; - - for (const auto & t : vocab.token_to_id) { - const auto & token = t.first; - const auto & id = t.second; - - // Count all non-normal tokens in the vocab while iterating - if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { - special_tokens_count_by_type++; - } - - // Skip single character tokens - if (token.length() > 1) { - bool is_tokenizable = false; - - // Split token string representation in two, in all possible ways - // and check if both halves can be matched to a valid token - for (unsigned i = 1; i < token.length();) { - const auto left = token.substr(0, i); - const auto right = token.substr(i); - - // check if we didnt partition in the middle of a utf sequence - auto utf = utf8_len(left.at(left.length() - 1)); - - if (utf == 1) { - if (vocab.token_to_id.find(left) != vocab.token_to_id.end() && - vocab.token_to_id.find(right) != vocab.token_to_id.end() ) { - is_tokenizable = true; - break; - } - i++; - } else { - // skip over the rest of multibyte utf sequence - i += utf - 1; - } - } - - if (!is_tokenizable) { - // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 - // it's faster to re-filter them here, since there are way less candidates now - - // Calculate a total "utf" length of a token string representation - size_t utf8_str_len = 0; - for (unsigned i = 0; i < token.length();) { - utf8_str_len++; - i += utf8_len(token.at(i)); - } - - // And skip the ones which are one character - if (utf8_str_len > 1) { - // At this point what we have left are special tokens only - vocab.special_tokens_cache[token] = id; - - // Count manually found special tokens - special_tokens_count_from_verification++; - - // If this manually found special token is not marked as such, flag a mismatch - if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) { - special_tokens_definition_mismatch = true; - } - } - } + for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { + if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) { + vocab.cache_special_tokens.push_back(id); } } - if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) { - LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n", - __func__, - special_tokens_count_from_verification, vocab.id_to_token.size(), - special_tokens_count_by_type, vocab.id_to_token.size() - ); - } else { - LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", - __func__, - special_tokens_count_from_verification, vocab.id_to_token.size() - ); + std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(), + [&] (const llama_vocab::id a, const llama_vocab::id b) { + return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size(); + } + ); + + LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size()); + } + + // build token to piece cache + { + size_t size_cache = 0; + + std::vector cache_token_to_piece(n_vocab); + + for (uint32_t id = 0; id < n_vocab; ++id) { + cache_token_to_piece[id] = llama_token_to_piece(&model, id, true); + + size_cache += cache_token_to_piece[id].size(); + } + + std::swap(vocab.cache_token_to_piece, cache_token_to_piece); + + LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0); + } + + // Handle per token attributes + //NOTE: Each model customizes per token attributes. + //NOTE: Per token attributes are missing from the GGUF file. + //TODO: Extract attributes from GGUF file. + { + auto _contains_any = [] (const std::string &str, const std::vector &substrs) -> bool { + for (auto substr : substrs) { + if (str.find(substr) < std::string::npos) { + return true; + } + } + return false; + }; + + auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) { + uint32_t current = vocab.id_to_token.at(id).attr; + current = value ? (current | attr) : (current & ~attr); + vocab.id_to_token[id].attr = (llama_token_attr) current; + }; + + auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) { + _set_tokenid_attr(vocab.token_to_id.at(token), attr, value); + }; + + std::string model_name; + std::string tokenizer_pre; + + ml.get_key(LLM_KV_GENERAL_NAME, model_name, false); + ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + + // model name to lowercase + std::transform(model_name.begin(), model_name.end(), model_name.begin(), + [] (const std::string::value_type x) { + return std::tolower(x); + } + ); + + // set attributes by model/tokenizer name + if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) { + _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true); + } else if (_contains_any(model_name, {"phi-3", "phi3"})) { + for (auto id : vocab.cache_special_tokens) { + _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true); + } + for (auto token : {""}) { + _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true); + } + for (auto token : {"", "", "<|endoftext|>"}) { + _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); + } } } } @@ -4751,7 +4987,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx); + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -4791,6 +5027,16 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); } if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); } if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); } + + if (model.arch == LLM_ARCH_DEEPSEEK2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + } } // Returns false if cancelled by progress_callback @@ -4934,6 +5180,7 @@ static bool llm_load_tensors( // create tensors for the weights { const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_head = n_embd / hparams.n_head; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = n_embd_v_gqa; @@ -4946,8 +5193,6 @@ static bool llm_load_tensors( throw std::runtime_error("model has expert layers but no expert layers are used"); } - GGML_ASSERT(n_embd_gqa == n_embd_k_gqa); - ggml_context * ctx_input = ctx_map.at(model.buft_input.buft); ggml_context * ctx_output = ctx_map.at(model.buft_output.buft); ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix); @@ -4967,14 +5212,10 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - if (model.arch != LLM_ARCH_MINICPM){ - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); - // if output is NULL, init from the input tok embed - if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); - } + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -4992,10 +5233,10 @@ static bool llm_load_tensors( layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); @@ -5003,10 +5244,15 @@ static bool llm_load_tensors( layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + // optional MLP bias + layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } else { layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); @@ -5048,12 +5294,10 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5076,7 +5320,7 @@ static bool llm_load_tensors( layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); - layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.ffn_gate_exps) { layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); @@ -5178,11 +5422,9 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -5195,8 +5437,8 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, false); - layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, false); + layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); @@ -5214,12 +5456,10 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { // needs to be on GPU - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5249,47 +5489,6 @@ static bool llm_load_tensors( layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); } } break; - case LLM_ARCH_PERSIMMON: - { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - - { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); - } - - for (int i = 0; i < n_layer; ++i) { - ggml_context * ctx_layer = ctx_for_layer(i); - ggml_context * ctx_split = ctx_for_layer_split(i); - - auto & layer = model.layers[i]; - - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); - - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); - - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); - - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); - - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); - - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {64}); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {64}); - - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {64}); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}); - } - } break; case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: { @@ -5335,7 +5534,7 @@ static bool llm_load_tensors( layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); } else { - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); } layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}); @@ -5358,14 +5557,14 @@ static bool llm_load_tensors( layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); @@ -5376,6 +5575,9 @@ static bool llm_load_tensors( layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}); + layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); @@ -5427,18 +5629,16 @@ static bool llm_load_tensors( case LLM_ARCH_MPT: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, false); + model.pos_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED); // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, false); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); if (!model.output) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // needs to be on GPU - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU } } @@ -5449,31 +5649,31 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); - layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, false); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, false); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, false); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, false); - layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, false); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, false); - layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, false); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); // AWQ ScaleActivation layer - layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, false); + layer.ffn_act = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_STABLELM: @@ -5502,17 +5702,17 @@ static bool llm_load_tensors( layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, false); - layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, false); + layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {hparams.n_embd_head_k, hparams.n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED); // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, false); - layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, false); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); @@ -5555,12 +5755,10 @@ static bool llm_load_tensors( // output { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5658,8 +5856,8 @@ static bool llm_load_tensors( layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false); - layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); @@ -5696,17 +5894,20 @@ static bool llm_load_tensors( ggml_context* ctx_layer = ctx_for_layer(i); ggml_context* ctx_split = ctx_for_layer_split(i); - auto& layer = model.layers[i]; + auto & layer = model.layers[i]; layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }); - layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, false); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }); + + layer.rope_long = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); } } break; case LLM_ARCH_PLAMO: @@ -5875,9 +6076,7 @@ static bool llm_load_tensors( // output model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); // same as tok_embd, duplicated to allow offloading - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading const int64_t n_ff = hparams.n_ff; const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -5912,12 +6111,10 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -5968,12 +6165,10 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed, duplicated to allow offloading if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6034,9 +6229,7 @@ static bool llm_load_tensors( { model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); // init output from the input tok embed - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } for (int i = 0; i < n_layer; ++i) { @@ -6068,12 +6261,10 @@ static bool llm_load_tensors( // output { - model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (model.output == NULL) { - model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); - ml.n_created--; // artificial tensor - ml.size_data += ggml_nbytes(model.output); + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); } } @@ -6093,6 +6284,145 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_GPTNEOX: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}); + layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}); + + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}); + + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + } + } break; + case LLM_ARCH_ARCTIC: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + } + } break; + case LLM_ARCH_DEEPSEEK2: + { + bool is_lite = (hparams.n_layer == 27); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t q_lora_rank = hparams.n_lora_q; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + const uint32_t n_ff_exp = hparams.n_ff_exp; + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}); + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + if (!is_lite) { + layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}); + } + layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}); + + if (!is_lite) { + layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}); + layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k}); + } else { + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}); + } + layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}); + layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + if ((uint32_t) i < hparams.n_layer_dense_lead) { + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } else { + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + + GGML_ASSERT(hparams.n_expert > 0); + GGML_ASSERT(hparams.n_expert_used > 0); + + // MoE branch + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}); + + // Shared expert branch + layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * hparams.n_expert_shared, n_embd}); + layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared}); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6357,10 +6687,7 @@ static struct ggml_tensor * llm_build_inp_embd( inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); } else { -#ifdef GGML_USE_MPI - GGML_ASSERT(false && "not implemented"); -#endif - lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); + lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens); inpL = lctx.inp_embd; ggml_set_input(lctx.inp_embd); } @@ -6550,6 +6877,8 @@ static struct ggml_tensor * llm_build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, + bool scale_w, + float w_scale, const llm_build_cb & cb, int il) { int64_t n_embd = cur->ne[0]; @@ -6581,6 +6910,10 @@ static struct ggml_tensor * llm_build_moe_ffn( weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens); } + if (scale_w) { + weights = ggml_scale(ctx, weights, w_scale); + cb(weights, "ffn_moe_weights_scaled", il); + } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] @@ -6685,7 +7018,7 @@ static struct ggml_tensor * llm_build_kqv( cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); } @@ -6694,7 +7027,7 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); - if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { + if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -6823,7 +7156,7 @@ struct llm_build_context { const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache - const int32_t n_orig_ctx; + const int32_t n_ctx_orig; const bool flash_attn; @@ -6872,7 +7205,7 @@ struct llm_build_context { n_kv (worst_case ? kv_self.size : kv_self.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), - n_orig_ctx (cparams.n_yarn_orig_ctx), + n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -6919,17 +7252,20 @@ struct llm_build_context { cb(lctx.inp_K_shift, "K_shift", -1); ggml_set_input(lctx.inp_K_shift); + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * rope_factors = build_rope_factors(il); struct ggml_tensor * tmp = // we rotate only the first n_rot dimensions - ggml_rope_custom_inplace(ctx0, + ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_k, n_head_kv, n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), 0), - lctx.inp_K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + cb(tmp, "K_shifted", il); ggml_build_forward_expand(gf, tmp); } @@ -7032,6 +7368,17 @@ struct llm_build_context { return lctx.inp_pos; } + struct ggml_tensor * build_rope_factors(int il) { + // choose long/short freq factors based on the context size + const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max; + + if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) { + return model.layers[il].rope_long; + } + + return model.layers[il].rope_short; + } + struct ggml_tensor * build_inp_out_ids() { lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); cb(lctx.inp_out_ids, "inp_out_ids", -1); @@ -7139,16 +7486,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7177,9 +7524,9 @@ struct llm_build_context { cb(cur, "ffn_norm", il); cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, NULL, - model.layers[il].ffn_gate, NULL, - model.layers[il].ffn_down, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -7197,6 +7544,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); } @@ -7269,14 +7617,14 @@ struct llm_build_context { switch (model.type) { case MODEL_7B: - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); break; @@ -7381,16 +7729,16 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7502,14 +7850,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7625,16 +7973,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7678,6 +8026,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_GELU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -7777,16 +8126,16 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -7821,6 +8170,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, true, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -7954,213 +8304,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_persimmon() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head/2 == hparams.n_rot); - - struct ggml_tensor * cur; - struct ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); - - // inp_pos - contains the positions - struct ggml_tensor * inp_pos = build_inp_pos(); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * residual = inpL; - - cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, - model.layers[il].attn_norm_b, - LLM_NORM, cb, il); - cb(cur, "attn_norm", il); - - // self attention - { - cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - // split qkv - GGML_ASSERT(n_head_kv == n_head); - - struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, n_tokens); - cb(tmpqkv, "tmpqkv", il); - - struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); - cb(tmpqkv_perm, "tmpqkv", il); - - struct ggml_tensor * tmpq = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - 0 - ); - cb(tmpq, "tmpq", il); - - struct ggml_tensor * tmpk = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens - ); - cb(tmpk, "tmpk", il); - - // Q/K Layernorm - tmpq = llm_build_norm(ctx0, tmpq, hparams, - model.layers[il].attn_q_norm, - model.layers[il].attn_q_norm_b, - LLM_NORM, cb, il); - cb(tmpq, "tmpq", il); - - tmpk = llm_build_norm(ctx0, tmpk, hparams, - model.layers[il].attn_k_norm, - model.layers[il].attn_k_norm_b, - LLM_NORM, cb, il); - cb(tmpk, "tmpk", il); - - // RoPE the first n_rot of q/k, pass the other half, and concat. - struct ggml_tensor * qrot = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - ggml_element_size(tmpq) * n_embd_head, - ggml_element_size(tmpq) * n_embd_head * n_head, - 0 - ); - cb(qrot, "qrot", il); - - struct ggml_tensor * krot = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - ggml_element_size(tmpk) * n_embd_head, - ggml_element_size(tmpk) * n_embd_head * n_head, - 0 - ); - cb(krot, "krot", il); - - // get the second half of tmpq, e.g tmpq[n_rot:, :, :] - struct ggml_tensor * qpass = ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - ggml_element_size(tmpq) * n_embd_head, - ggml_element_size(tmpq) * n_embd_head * n_head, - ggml_element_size(tmpq) * n_rot - ); - cb(qpass, "qpass", il); - - struct ggml_tensor * kpass = ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - ggml_element_size(tmpk) * n_embd_head, - ggml_element_size(tmpk) * n_embd_head * n_head, - ggml_element_size(tmpk) * n_rot - ); - cb(kpass, "kpass", il); - - struct ggml_tensor * qrotated = ggml_rope_custom( - ctx0, qrot, inp_pos, n_rot, rope_type, 0, n_orig_ctx, - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(qrotated, "qrotated", il); - - struct ggml_tensor * krotated = ggml_rope_custom( - ctx0, krot, inp_pos, n_rot, rope_type, 0, n_orig_ctx, - freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(krotated, "krotated", il); - - // ggml currently only supports concatenation on dim=2 - // so we need to permute qrot, qpass, concat, then permute back. - qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3)); - cb(qrotated, "qrotated", il); - - krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3)); - cb(krotated, "krotated", il); - - qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); - cb(qpass, "qpass", il); - - kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); - cb(kpass, "kpass", il); - - struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3)); - cb(Q, "Q", il); - - Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3)); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = ggml_view_3d( - ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - ggml_element_size(tmpqkv_perm) * n_embd_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 - ); - cb(Vcur, "Vcur", il); - - cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Q, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - } - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - residual = ggml_get_rows(ctx0, residual, inp_out_ids); - } - - struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, - model.layers[il].ffn_norm_b, - LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, - NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, - NULL, - LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); - cb(cur, "ffn_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "l_out", il); - - inpL = cur; - } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm, - model.output_norm_b, - LLM_NORM, cb, -1); - cb(cur, "result_norm", -1); - - cur = ggml_mul_mat(ctx0, model.output, cur); - cb(cur, "result_output", -1); - - ggml_build_forward_expand(gf, cur); - - return gf; - } - struct ggml_cgraph * build_refact() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -8337,16 +8480,16 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8398,6 +8541,11 @@ struct llm_build_context { // attention layer norm cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_out_norm, model.layers[il].attn_out_norm_b, LLM_NORM, cb, il); + if (model.layers[il].attn_norm_2 != nullptr) { + cur = ggml_add(ctx0, cur, inpL); // re-add the layer input + cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_norm_2, model.layers[il].attn_norm_2_b, LLM_NORM, cb, il); + } + struct ggml_tensor * ffn_inp = cur; cb(ffn_inp, "ffn_inp", il); @@ -8777,16 +8925,16 @@ struct llm_build_context { } - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -8897,14 +9045,14 @@ struct llm_build_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); // using mode = 2 for neox mode - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9008,16 +9156,16 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9122,16 +9270,16 @@ struct llm_build_context { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9166,6 +9314,7 @@ struct llm_build_context { model.layers[il].ffn_down_exps, n_expert, n_expert_used, LLM_FFN_SILU, false, + false, 0.0, cb, il); cb(cur, "ffn_moe_out", il); @@ -9274,8 +9423,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9285,8 +9434,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9362,6 +9511,9 @@ struct llm_build_context { // self-attention { + // rope freq factors for 128k context + struct ggml_tensor * rope_factors = build_rope_factors(il); + struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, @@ -9393,8 +9545,8 @@ struct llm_build_context { Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_custom( - ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); @@ -9402,8 +9554,8 @@ struct llm_build_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx, + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9509,15 +9661,15 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, - n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, - n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr, + n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -9717,16 +9869,16 @@ struct llm_build_context { cb(tmpk, "tmpk", il); cb(Vcur, "Vcur", il); - struct ggml_tensor * Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + struct ggml_tensor * Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - struct ggml_tensor * Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + struct ggml_tensor * Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9833,16 +9985,16 @@ struct llm_build_context { // cb(Vcur, "Vcur", il); // } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -9950,16 +10102,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10080,16 +10232,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10154,7 +10306,7 @@ struct llm_build_context { cb(cur, "lmhead_scaling", -1); // lm_head - cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cur = ggml_mul_mat(ctx0, model.output, cur); cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); @@ -10200,18 +10352,18 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, - n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, + n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k))); cb(Qcur, "Qcur_scaled", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, - n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, + n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur", il); @@ -10320,16 +10472,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10610,16 +10762,16 @@ struct llm_build_context { cb(Kcur, "Kcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10741,16 +10893,16 @@ struct llm_build_context { cb(Vcur, "Vcur", il); } - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, - n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Kcur, "Kcur", il); @@ -10813,6 +10965,508 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_gptneox() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // ffn + if (hparams.use_par_res) { + // attention and ffn are computed in parallel + // x = x + attn(ln1(x)) + ffn(ln2(x)) + + struct ggml_tensor * attn_out = cur; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "ffn_out", il); + + inpL = ggml_add(ctx0, cur, attn_out); + cb(inpL, "l_out", il); + } else { + // attention and ffn are computed sequentially + // x = x + attn(ln1(x)) + // x = x + ffn(ln2(x)) + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + NULL, + LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + + inpL = ggml_add(ctx0, cur, ffn_inp); + cb(inpL, "l_out", il); + } + } + + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, + model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_arctic() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp); + cb(ffn_out, "ffn_out", il); + + // MoE + cur = llm_build_norm(ctx0, inpSA, hparams, + model.layers[il].ffn_norm_exps, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm_exps", il); + + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + cb, il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_deepseek2() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + bool is_lite = (hparams.n_layer == 27); + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); + const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot; + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + struct ggml_tensor * q = NULL; + if (!is_lite) { + // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(q, "q", il); + + q = llm_build_norm(ctx0, q, hparams, + model.layers[il].attn_q_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(q, "q", il); + + // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); + cb(q, "q", il); + } else { + q = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(q, "q", il); + } + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + 0); + cb(q_nope, "q_nope", il); + + // and {n_head * n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, hparams.n_embd_head_k), + ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_pe_compresseed, "kv_pe_compresseed", il); + + // split into {kv_lora_rank, n_tokens} + struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, + kv_pe_compresseed->nb[1], + 0); + cb(kv_compressed, "kv_compressed", il); + + // and {n_embd_head_qk_rope, n_tokens} + struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, + kv_pe_compresseed->nb[1], + kv_pe_compresseed->nb[1], + ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm + kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams, + model.layers[il].attn_kv_a_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(kv_compressed, "kv_compressed", il); + + // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} + struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); + cb(kv, "kv", il); + + // split into {n_head * n_embd_head_qk_nope, n_tokens} + struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + 0); + cb(k_nope, "k_nope", il); + + // and {n_head * n_embd_head_v, n_tokens} + struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_row_size(kv->type, (n_embd_head_qk_nope))); + cb(v_states, "v_states", il); + + v_states = ggml_cont(ctx0, v_states); + cb(v_states, "v_states", il); + + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + 0); + cb(v_states, "v_states", il); + + q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + q_pe = ggml_rope_ext( + ctx0, q_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(q_pe, "q_pe", il); + + // shared RoPE key + k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE + k_pe = ggml_rope_ext( + ctx0, k_pe, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor_scaled, beta_fast, beta_slow + ); + cb(k_pe, "k_pe", il); + + struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); + cb(q_states, "q_states", il); + + struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); + cb(k_states, "k_states", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, NULL, + k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + true, hparams.expert_weights_scale, + cb, il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up_shexp, NULL, + model.layers[il].ffn_gate_shexp, NULL, + model.layers[il].ffn_down_shexp, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10929,10 +11583,6 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_starcoder(); } break; - case LLM_ARCH_PERSIMMON: - { - result = llm.build_persimmon(); - } break; case LLM_ARCH_REFACT: { result = llm.build_refact(); @@ -11027,6 +11677,18 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_olmo(); } break; + case LLM_ARCH_GPTNEOX: + { + result = llm.build_gptneox(); + } break; + case LLM_ARCH_ARCTIC: + { + result = llm.build_arctic(); + } break; + case LLM_ARCH_DEEPSEEK2: + { + result = llm.build_deepseek2(); + } break; default: GGML_ASSERT(false); } @@ -11372,11 +12034,6 @@ static void llama_graph_compute( llama_context & lctx, ggml_cgraph * gf, int n_threads) { -#ifdef GGML_USE_MPI - const int64_t n_layer = lctx.model.hparams.n_layer; - ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); -#endif - #ifdef GGML_USE_METAL if (ggml_backend_is_metal(lctx.backend_metal)) { ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads); @@ -11391,10 +12048,6 @@ static void llama_graph_compute( ggml_backend_sched_graph_compute_async(lctx.sched, gf); // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched)); - -#ifdef GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); -#endif } // decode a batch of tokens by evaluating the transformer @@ -11432,12 +12085,6 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; -#ifdef GGML_USE_MPI - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); -#endif - auto & kv_self = lctx.kv_self; const int64_t n_embd = hparams.n_embd; @@ -12058,27 +12705,27 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL; } static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN; } static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL; } static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE; } static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); - return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; + return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED; } static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { @@ -12331,6 +12978,7 @@ struct llm_tokenizer_bpe { }); break; case LLAMA_VOCAB_PRE_TYPE_DBRX: + case LLAMA_VOCAB_PRE_TYPE_SMAUG: word_collection = unicode_regex_split(text, { // same as llama3 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", @@ -12387,6 +13035,7 @@ struct llm_tokenizer_bpe { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }); break; + case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: word_collection = unicode_regex_split(text, { // original regex from tokenizer.json @@ -12552,7 +13201,7 @@ struct llm_tokenizer_wpm { llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} void tokenize(const std::string & text, std::vector & output) { - auto * token_map = &vocab.token_to_id; + const auto & token_map = vocab.token_to_id; // normalize and split by whitespace std::vector words = preprocess(text); @@ -12567,108 +13216,89 @@ struct llm_tokenizer_wpm { } // prepend phantom space - std::string word1 = "\xe2\x96\x81" + word; - int n = word1.size(); + const std::string word1 = "\xe2\x96\x81" + word; + const int n = word1.size(); + + const size_t current_tokens = output.size(); // we're at the start of a new word - int i = 0; - bool match_any = false; - // move through character position in word - while (i < n) { + for (int i = 0; i < n; ++i) { // loop through possible match length bool match = false; for (int j = n; j > i; j--) { - auto it = token_map->find(word1.substr(i, j - i)); - if (it != token_map->end()) { + auto it = token_map.find(word1.substr(i, j - i)); + if (it != token_map.end()) { output.push_back(it->second); match = true; - match_any = true; - i = j; + i = j - 1; break; } } - // must be an unknown character - if (!match) { - i++; + if (!match) { // discard all + output.resize(current_tokens); + break; // and discard next tokens } } // we didn't find any matches for this word - if (!match_any) { + if (current_tokens == output.size()) { output.push_back(vocab.special_unk_id); } } } std::vector preprocess(const std::string & text) { - std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + std::vector words(1, ""); - // strip accents, strip control, uniformize whitespace, - // to lowercase, pad chinese characters, pad punctuation - std::string new_str = ""; - for (uint32_t code : cpts_nfd) { - const codepoint_flags flags = unicode_cpt_flags(code); - if (flags.is_accent_mark || flags.is_control) { + for (const char32_t cpt : cpts_nfd) { + const auto flags = unicode_cpt_flags(cpt); + + if (flags.is_whitespace) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } continue; } - code = unicode_tolower(code); - if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ? - code = ' '; + + assert (!flags.is_separator); + if (cpt == 0 || cpt == 0xFFFD || flags.is_control) { + continue; } - std::string s = unicode_cpt_to_utf8(code); - if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) { - new_str += " "; - new_str += s; - new_str += " "; + + const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { + if (words.back().size()) { // finish previous word if any + words.emplace_back(); + } + words.back() = s; // single char word + words.emplace_back(); // start a new word } else { - new_str += s; + words.back() += s; // append char to word } } - // split by whitespace - uint64_t l = 0; - uint64_t r = 0; - std::vector words; - while (r < new_str.size()) { - // if is whitespace - if (isspace(new_str[r], std::locale::classic())) { - if (r > l) words.push_back(new_str.substr(l, (r - l))); - l = r + 1; - r = l; - } else { - r += 1; - } - } - if (r > l) { - words.push_back(new_str.substr(l, (r - l))); + if (!words.back().size()) { + words.pop_back(); } + return words; } - bool is_ascii_punct(uint32_t code) { - if (code > 0xFF) { - return false; - } - auto c = char(static_cast(code)); - return ispunct(c, std::locale::classic()); - } - - bool is_chinese_char(uint32_t cpt) { - if ((cpt >= 0x4E00 && cpt <= 0x9FFF) || - (cpt >= 0x3400 && cpt <= 0x4DBF) || + static bool is_chinese_char(uint32_t cpt) { + return + (cpt >= 0x04E00 && cpt <= 0x09FFF) || + (cpt >= 0x03400 && cpt <= 0x04DBF) || (cpt >= 0x20000 && cpt <= 0x2A6DF) || (cpt >= 0x2A700 && cpt <= 0x2B73F) || (cpt >= 0x2B740 && cpt <= 0x2B81F) || (cpt >= 0x2B920 && cpt <= 0x2CEAF) || // this should be 0x2B820 but in hf rust code it is 0x2B920 - (cpt >= 0xF900 && cpt <= 0xFAFF) || - (cpt >= 0x2F800 && cpt <= 0x2FA1F) || - (cpt >= 0x3000 && cpt <= 0x303F) || - (cpt >= 0xFF00 && cpt <= 0xFFEF)) { - return true; // NOLINT - } - return false; + (cpt >= 0x0F900 && cpt <= 0x0FAFF) || + (cpt >= 0x2F800 && cpt <= 0x2FA1F); + //(cpt >= 0x3000 && cpt <= 0x303F) || + //(cpt >= 0xFF00 && cpt <= 0xFFEF); } const llama_vocab & vocab; @@ -12712,9 +13342,9 @@ struct fragment_buffer_variant { static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) { // for each special token - for (const auto & st: vocab.special_tokens_cache) { - const auto & special_token = st.first; - const auto & special_id = st.second; + for (const llama_vocab::id special_id : vocab.cache_special_tokens) { + const auto & data = vocab.id_to_token[special_id]; + const auto & special_token = data.text; // for each text fragment std::forward_list::iterator it = buffer.begin(); @@ -12723,7 +13353,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // if a fragment is text ( not yet processed ) if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - auto * raw_text = &(fragment.raw_text); + auto & raw_text = fragment.raw_text; auto raw_text_base_offset = fragment.offset; auto raw_text_base_length = fragment.length; @@ -12733,7 +13363,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // find the first occurrence of a given special token in this fragment // passing offset argument only limit the "search area" but match coordinates // are still relative to the source full raw_text - auto match = raw_text->find(special_token, raw_text_base_offset); + auto match = raw_text.find(special_token, raw_text_base_offset); // no occurrences found, stop processing this fragment for a given special token if (match == std::string::npos) break; @@ -12751,13 +13381,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< if (match > raw_text_base_offset) { // left const int64_t left_reminder_offset = raw_text_base_offset + 0; - const int64_t left_reminder_length = match - raw_text_base_offset; - buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); + int64_t left_reminder_length = match - raw_text_base_offset; + + if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) { + while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) { + left_reminder_length--; + } + } + + if (left_reminder_length > 0) { + buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length); + it++; + } #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); #endif - it++; } // special token @@ -12766,16 +13405,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list< // right if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { - const int64_t right_reminder_offset = match + special_token.length(); - const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); - buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); + int64_t right_reminder_offset = match + special_token.length(); + int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); + + if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) { + while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) { + right_reminder_offset++; + right_reminder_length--; + } + } + + if (right_reminder_length > 0) { + buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length); + it++; + } #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); #endif - it++; - if (source == 0) { buffer.erase_after(buffer.before_begin()); } else { @@ -12821,23 +13469,21 @@ static std::vector llama_tokenize_internal(const llama_vocab & // tokenizer.encode('', add_special_tokens=True) returns [1] // tokenizer.encode('', add_special_tokens=False) returns [] + bool is_prev_special = false; + if (add_special && vocab.special_add_bos != 0) { GGML_ASSERT(vocab.special_bos_id != -1); output.push_back(vocab.special_bos_id); + is_prev_special = true; } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { - // without adding this leading whitespace, we do not get the same results as the original tokenizer - - // TODO: It's likely possible to get rid of this string copy entirely - // by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer - // and passing 'add space prefix' as bool argument - // auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); - if (&fragment == &fragment_buffer.front()) { - if (vocab.add_space_prefix) { - raw_text = " " + raw_text; // prefix with space if the first token is not special + + if (vocab.add_space_prefix) { + if (!output.size() || is_prev_special) { // prefix with space if first token + raw_text = " " + raw_text; } } @@ -12849,6 +13495,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & tokenizer.tokenize(raw_text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) output.push_back(fragment.token); + is_prev_special = true; } } @@ -13011,7 +13658,7 @@ static std::pair llama_grammar_match_char( const uint32_t chr) { bool found = false; - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT @@ -13020,6 +13667,10 @@ static std::pair llama_grammar_match_char( // inclusive range, e.g. [a-z] found = found || (pos->value <= chr && chr <= pos[1].value); pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + found = true; + pos += 1; } else { // exact char match, e.g. [a] or "a" found = found || pos->value == chr; @@ -13037,7 +13688,7 @@ static bool llama_grammar_match_partial_char( const llama_grammar_element * pos, const llama_partial_utf8 partial_utf8) { - bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); uint32_t partial_value = partial_utf8.value; @@ -13067,6 +13718,9 @@ static bool llama_grammar_match_partial_char( return is_positive_char; } pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + return true; } else { // exact char match, e.g. [a] or "a" if (low <= pos->value && pos->value <= high) { @@ -13127,6 +13781,7 @@ static void llama_grammar_advance_stack( } case LLAMA_GRETYPE_CHAR: case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_ANY: if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have new_stacks.emplace_back(stack); @@ -13849,7 +14504,7 @@ void llama_sample_repetition_penalties( void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { GGML_ASSERT(ctx); - const int64_t t_start_sample_us = ggml_time_us(); + int64_t t_start_sample_us = ggml_time_us(); bool allow_eog = false; for (const auto & stack : grammar->stacks) { @@ -13861,12 +14516,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c std::vector, llama_partial_utf8>> candidates_decoded; candidates_decoded.reserve(candidates->size); - std::vector candidates_grammar; + + std::vector candidates_grammar; candidates_grammar.reserve(candidates->size); for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const std::string piece = llama_token_to_piece(ctx, id, false); + const llama_token id = candidates->data[i].id; + const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id); if (llama_token_is_eog(&ctx->model, id)) { if (!allow_eog) { @@ -14066,7 +14722,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar GGML_ASSERT(false); } - const std::string piece = llama_token_to_piece(ctx, token, false); + const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar->partial_utf8); @@ -14082,260 +14738,6 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } -// -// Beam search -// - -struct llama_beam { - std::vector tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Initialize end-of-beam to false. Callback sets this to true. - // Sort beams by probability. In case of ties, prefer beams at eob. - bool operator<(const llama_beam & rhs) const { - return std::make_pair(p, eob) < std::make_pair(rhs.p, rhs.eob); - } - // Shift off first n tokens and discard them. - void shift_tokens(const size_t n) { - if (n) { - std::copy(tokens.begin() + n, tokens.end(), tokens.begin()); - tokens.resize(tokens.size() - n); - } - } - llama_beam_view view() const { return {tokens.data(), tokens.size(), p, eob}; } -}; - -// A struct for calculating logit-related info. -struct llama_logit_info { - const float * const logits; - const int n_vocab; - const float max_l; - const float normalizer; - struct sum_exp { - float max_l; - float operator()(float sum, float l) const { return sum + std::exp(l - max_l); } - }; - llama_logit_info(llama_context * ctx) - : logits(llama_get_logits(ctx)) - , n_vocab(llama_n_vocab(llama_get_model(ctx))) - , max_l(*std::max_element(logits, logits + n_vocab)) - , normalizer(1.0f / std::accumulate(logits, logits + n_vocab, 0.0f, sum_exp{max_l})) - { } - llama_token_data get_token_data(const llama_token token_id) const { - constexpr auto p = std::numeric_limits::quiet_NaN(); // never used - return {token_id, logits[token_id], p}; - } - // Return top k token_data by logit. - std::vector top_k(size_t k) { - std::vector min_heap; // min-heap by logit - const llama_token k_min = std::min(static_cast(k), n_vocab); - min_heap.reserve(k_min); - for (llama_token token_id = 0 ; token_id < k_min ; ++token_id) { - min_heap.push_back(get_token_data(token_id)); - } - auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; - std::make_heap(min_heap.begin(), min_heap.end(), comp); - for (llama_token token_id = k_min ; token_id < n_vocab ; ++token_id) { - if (min_heap.front().logit < logits[token_id]) { - std::pop_heap(min_heap.begin(), min_heap.end(), comp); - min_heap.back().id = token_id; - min_heap.back().logit = logits[token_id]; - std::push_heap(min_heap.begin(), min_heap.end(), comp); - } - } - return min_heap; - } - float probability_from_logit(float logit) const { - return normalizer * std::exp(logit - max_l); - } -}; - -struct llama_beam_search_data { - llama_context * ctx; - size_t n_beams; - int n_past; - int n_predict; - std::vector beams; - std::vector next_beams; - - // Re-calculated on each loop iteration - size_t common_prefix_length; - - // Used to communicate to/from callback on beams state. - std::vector beam_views; - - llama_beam_search_data(llama_context * ctx, size_t n_beams, int n_past, int n_predict) - : ctx(ctx) - , n_beams(n_beams) - , n_past(n_past) - , n_predict(n_predict) - , beam_views(n_beams) { - beams.reserve(n_beams); - next_beams.reserve(n_beams); - } - - // Collapse beams to a single beam given by index. - void collapse_beams(const size_t beam_idx) { - if (0u < beam_idx) { - std::swap(beams[0], beams[beam_idx]); - } - beams.resize(1); - } - - // Min-heaps are used to efficiently collect the top-k elements (k=n_beams). - // The repetitive patterns below reflect the 2 stages of heaps: - // * Gather elements until the vector is full, then call std::make_heap() on it. - // * If the heap is full and a new element is found that should be included, pop the - // least element to the back(), replace it with the new, then push it into the heap. - void fill_next_beams_by_top_probabilities(llama_beam & beam) { - // Min-heaps use a greater-than comparator. - const auto comp = [](const llama_beam & a, const llama_beam & b) { return a.p > b.p; }; - if (beam.eob) { - // beam is at end-of-sentence, so just copy it to next_beams if its probability is high enough. - if (next_beams.size() < n_beams) { - next_beams.push_back(std::move(beam)); - if (next_beams.size() == n_beams) { - std::make_heap(next_beams.begin(), next_beams.end(), comp); - } - } else if (next_beams.front().p < beam.p) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = std::move(beam); - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } else { - // beam is not at end-of-sentence, so branch with next top_k tokens. - if (!beam.tokens.empty()) { - llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0)); - } - llama_logit_info logit_info(ctx); - std::vector next_tokens = logit_info.top_k(n_beams); - - // Clear the kv slot so that other beams may try different tokens at this position. The llama_decode() - // call in loop() will conclusively fill in the kv slot once the beams converge at this position. - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - - size_t i=0; - if (next_beams.size() < n_beams) { - for (; next_beams.size() < n_beams ; ++i) { - llama_beam next_beam = beam; - next_beam.tokens.push_back(next_tokens[i].id); - next_beam.p *= logit_info.probability_from_logit(next_tokens[i].logit); - next_beams.push_back(std::move(next_beam)); - } - std::make_heap(next_beams.begin(), next_beams.end(), comp); - } else { - for (; next_beams.front().p == 0.0f ; ++i) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = beam; - next_beams.back().tokens.push_back(next_tokens[i].id); - next_beams.back().p *= logit_info.probability_from_logit(next_tokens[i].logit); - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } - for (; i < n_beams ; ++i) { - const float next_p = beam.p * logit_info.probability_from_logit(next_tokens[i].logit); - if (next_beams.front().p < next_p) { - std::pop_heap(next_beams.begin(), next_beams.end(), comp); - next_beams.back() = beam; - next_beams.back().tokens.push_back(next_tokens[i].id); - next_beams.back().p = next_p; - std::push_heap(next_beams.begin(), next_beams.end(), comp); - } - } - } - } - - // Find common_prefix_length based on beams. - // Requires beams is not empty. - size_t find_common_prefix_length() { - size_t common_prefix_length = beams[0].tokens.size(); - for (size_t i = 1 ; i < beams.size() ; ++i) { - common_prefix_length = std::min(common_prefix_length, beams[i].tokens.size()); - for (size_t j = 0 ; j < common_prefix_length ; ++j) { - if (beams[0].tokens[j] != beams[i].tokens[j]) { - common_prefix_length = j; - break; - } - } - } - return common_prefix_length; - } - - // Construct beams_state to send back to caller via the callback function. - // Side effect: set common_prefix_length = find_common_prefix_length(); - llama_beams_state get_beams_state(const bool last_call) { - for (size_t i = 0 ; i < beams.size() ; ++i) { - beam_views[i] = beams[i].view(); - } - common_prefix_length = find_common_prefix_length(); - return {beam_views.data(), beams.size(), common_prefix_length, last_call}; - } - - // Loop: - // * while i < n_predict, AND - // * any of the beams have not yet reached end-of-beam (eob), AND - // * the highest probability beam(s) (plural in case of ties) are not at end-of-sentence - // (since all other beam probabilities can only decrease) - void loop(const llama_beam_search_callback_fn_t callback, void * const callback_data) { - beams.push_back({{}, 1.0f, false}); // Start with one empty beam w/ probability = 1.0 and !eob. - const auto not_eob = [](const llama_beam & beam) { return !beam.eob; }; - for (int i = 0 ; i < n_predict && std::any_of(beams.begin(),beams.end(),not_eob) && - !beams[top_beam_index()].eob ; ++i) { - callback(callback_data, get_beams_state(false)); // Sets common_prefix_length - update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. - if (common_prefix_length) { - llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0)); - n_past += common_prefix_length; - } - // Zero-out next_beam probabilities to place them last in following min-heap. - std::for_each(next_beams.begin(), next_beams.end(), [](llama_beam & beam) { beam.p = 0.0f; }); - for (llama_beam & beam : beams) { - beam.shift_tokens(common_prefix_length); - fill_next_beams_by_top_probabilities(beam); - } - // next_beams become the beams of next/final iteration. Swap them to re-use memory. - beams.swap(next_beams); - renormalize_beam_probabilities(beams); - } - collapse_beams(top_beam_index()); - callback(callback_data, get_beams_state(true)); - } - - // As beams grow, the cumulative probabilities decrease. - // Renormalize them to avoid floating point underflow. - static void renormalize_beam_probabilities(std::vector & beams) { - const auto sum_p = [](float sum, llama_beam & beam) { return sum + beam.p; }; - const float inv_sum = 1.0f / std::accumulate(beams.begin(), beams.end(), 0.0f, sum_p); - std::for_each(beams.begin(), beams.end(), [=](llama_beam & beam) { beam.p *= inv_sum; }); - } - - // Assumes beams is non-empty. Uses llama_beam::operator<() for ordering. - size_t top_beam_index() { - return std::max_element(beams.begin(), beams.end()) - beams.begin(); - } - - // Copy (p,eob) for each beam which may have been changed by the callback. - void update_beams_from_beam_views() { - for (size_t i = 0 ; i < beams.size() ; ++i) { - beams[i].p = beam_views[i].p; - beams[i].eob = beam_views[i].eob; - } - } -}; - -void llama_beam_search(llama_context * ctx, - llama_beam_search_callback_fn_t callback, void * callback_data, - size_t n_beams, int n_past, int n_predict) { - assert(ctx); - const int64_t t_start_sample_us = ggml_time_us(); - - llama_beam_search_data beam_search_data(ctx, n_beams, n_past, n_predict); - - beam_search_data.loop(callback, callback_data); - - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - ctx->n_sample++; -} - // // quantization // @@ -14551,8 +14953,6 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; - else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && - (qs.i_attention_wv < qs.n_attention_wv/8 || qs.i_attention_wv >= 7*qs.n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; if (qs.model.type == MODEL_70B) { // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with @@ -14855,6 +15255,14 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (imatrix_data) { LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); qs.has_imatrix = true; + // check imatrix for nans or infs + for (const auto & kv : *imatrix_data) { + for (float f : kv.second) { + if (!std::isfinite(f)) { + throw std::runtime_error(format("imatrix contains non-finite value %f\n", f)); + } + } + } } } @@ -15548,7 +15956,7 @@ bool llama_supports_mlock(void) { } bool llama_supports_gpu_offload(void) { -#if defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \ +#if defined(GGML_USE_CUDA) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \ defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. return true; @@ -15566,10 +15974,6 @@ void llama_backend_init(void) { struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); } - -#ifdef GGML_USE_MPI - ggml_mpi_backend_init(); -#endif } void llama_numa_init(enum ggml_numa_strategy numa) { @@ -15579,9 +15983,6 @@ void llama_numa_init(enum ggml_numa_strategy numa) { } void llama_backend_free(void) { -#ifdef GGML_USE_MPI - ggml_mpi_backend_free(); -#endif ggml_quantize_free(); } @@ -15612,7 +16013,7 @@ struct llama_model * llama_load_model_from_file( return true; }; } - if (params.rpc_servers != nullptr) { + if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') { // split the servers set them into model->rpc_servers std::string servers(params.rpc_servers); size_t pos = 0; @@ -15666,6 +16067,11 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } + if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) { + LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__); + return nullptr; + } + llama_context * ctx = new llama_context(*model); const auto & hparams = model->hparams; @@ -15704,8 +16110,8 @@ struct llama_context * llama_new_context_with_model( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); - cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : - hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : + cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : + hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn : hparams.n_ctx_train; cparams.cb_eval = params.cb_eval; @@ -15724,6 +16130,7 @@ struct llama_context * llama_new_context_with_model( cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f; } + cparams.yarn_attn_factor *= hparams.rope_attn_factor; cparams.causal_attn = hparams.causal_attn; if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { @@ -15769,17 +16176,7 @@ struct llama_context * llama_new_context_with_model( if (!hparams.vocab_only) { // initialize backends -#if defined(GGML_USE_RPC) - for (auto & server : model->rpc_servers) { - ggml_backend_t backend = ggml_backend_rpc_init(server.c_str()); - if (backend == nullptr) { - LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str()); - llama_free(ctx); - return nullptr; - } - ctx->backends.push_back(backend); - } -#elif defined(GGML_USE_METAL) +#if defined(GGML_USE_METAL) if (model->n_gpu_layers > 0) { ctx->backend_metal = ggml_backend_metal_init(); if (ctx->backend_metal == nullptr) { @@ -15818,7 +16215,7 @@ struct llama_context * llama_new_context_with_model( return nullptr; } if (model->split_mode == LLAMA_SPLIT_MODE_NONE) { - ggml_backend_t backend = ggml_backend_vk_init(0); + ggml_backend_t backend = ggml_backend_vk_init(model->main_gpu); if (backend == nullptr) { LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__); llama_free(ctx); @@ -15871,6 +16268,19 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(backend); } +#endif +#if defined(GGML_USE_RPC) + if (model->n_gpu_layers > 0) { + for (const auto & endpoint : model->rpc_servers) { + ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str()); + if (backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str()); + llama_free(ctx); + return nullptr; + } + ctx->backends.push_back(backend); + } + } #endif ctx->backend_cpu = ggml_backend_cpu_init(); if (ctx->backend_cpu == nullptr) { @@ -15982,20 +16392,6 @@ struct llama_context * llama_new_context_with_model( } } -#ifdef GGML_USE_MPI - ctx->ctx_mpi = ggml_mpi_init(); - - if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { - // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); - //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; - llama_backend_free(); - exit(1); - } -#endif - return ctx; } @@ -16032,7 +16428,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { // these models do not use RoPE case LLM_ARCH_GPT2: case LLM_ARCH_GPTJ: - case LLM_ARCH_GPTNEOX: case LLM_ARCH_MPT: case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: @@ -16052,13 +16447,14 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: case LLM_ARCH_OLMO: + case LLM_ARCH_ARCTIC: + case LLM_ARCH_DEEPSEEK2: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 case LLM_ARCH_FALCON: case LLM_ARCH_GROK: case LLM_ARCH_DBRX: - case LLM_ARCH_PERSIMMON: case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_STABLELM: @@ -16069,6 +16465,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_GEMMA: case LLM_ARCH_STARCODER2: + case LLM_ARCH_GPTNEOX: return LLAMA_ROPE_TYPE_NEOX; // all model arches should be listed explicitly here @@ -16228,6 +16625,7 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const } // make tensors + cvec.tensors.reserve(model.hparams.n_layer); cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 for (size_t il = 1; il < model.hparams.n_layer; il++) { struct ggml_context * ctx = ctx_map.at(model.buft_layer[il].buft); @@ -16236,6 +16634,8 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const } // allocate tensors / buffers and zero + cvec.ctxs.reserve(ctx_map.size()); + cvec.bufs.reserve(ctx_map.size()); for (auto it : ctx_map) { ggml_backend_buffer_type_t buft = it.first; ggml_context * ctx = it.second; @@ -17444,6 +17844,14 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_ ctx->cparams.n_threads_batch = n_threads_batch; } +uint32_t llama_n_threads(struct llama_context * ctx) { + return ctx->cparams.n_threads; +} + +uint32_t llama_n_threads_batch(struct llama_context * ctx) { + return ctx->cparams.n_threads_batch; +} + void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { ctx->abort_callback = abort_callback; ctx->abort_callback_data = abort_callback_data; @@ -17655,9 +18063,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token) return model->vocab.id_to_token[token].score; } -llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) { +llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) { GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); - return model->vocab.id_to_token[token].type; + return model->vocab.id_to_token[token].attr; } bool llama_token_is_eog(const struct llama_model * model, llama_token token) { @@ -17667,6 +18075,10 @@ bool llama_token_is_eog(const struct llama_model * model, llama_token token) { ); } +bool llama_token_is_control(const struct llama_model * model, llama_token token) { + return llama_is_control_token(model->vocab, token); +} + llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } @@ -17738,7 +18150,16 @@ static std::string llama_decode_text(const std::string & text) { const auto cpts = unicode_cpts_from_utf8(text); for (const auto cpt : cpts) { - decoded_text += unicode_utf8_to_byte(unicode_cpt_to_utf8(cpt)); + const auto utf8 = unicode_cpt_to_utf8(cpt); + try { + decoded_text += unicode_utf8_to_byte(utf8); + } catch (const std::out_of_range & e) { + decoded_text += "[UNK_BYTE_0x"; + for (const auto c : utf8) { + decoded_text += format("%02x", (uint8_t) c); + } + decoded_text += text + "]"; + } } return decoded_text; @@ -17746,69 +18167,88 @@ static std::string llama_decode_text(const std::string & text) { // does not write null-terminator to buf int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) { + // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843 + if (!special && llama_is_control_token(model->vocab, token)) { + return 0; + } + + // if we have a cache - use it + { + const auto & cache = model->vocab.cache_token_to_piece; + + if (!cache.empty()) { + const auto & res = cache.at(token); + if (length < (int) res.size()) { + return -(int) res.size(); + } + memcpy(buf, res.c_str(), res.size()); + return res.size(); + } + } + if (0 <= token && token < llama_n_vocab(model)) { switch (llama_vocab_get_type(model->vocab)) { - case LLAMA_VOCAB_TYPE_WPM: - case LLAMA_VOCAB_TYPE_SPM: { - // NOTE: we accept all unsupported token types, - // suppressing them like CONTROL tokens. - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - llama_unescape_whitespace(result); - if (length < (int) result.length()) { - return -(int) result.length(); + case LLAMA_VOCAB_TYPE_WPM: + case LLAMA_VOCAB_TYPE_SPM: { + // NOTE: we accept all unsupported token types, + // suppressing them like CONTROL tokens. + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + llama_unescape_whitespace(result); + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { + std::string result = model->vocab.id_to_token[token].text; + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT + if (length < 3) { + return -3; + } + memcpy(buf, "\xe2\x96\x85", 3); + return 3; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + buf[0] = llama_token_to_byte(model->vocab, token); + return 1; } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if ( - (llama_is_user_defined_token(model->vocab, token)) || - (llama_is_control_token (model->vocab, token) && special)) { - std::string result = model->vocab.id_to_token[token].text; - if (length < (int) result.length()) { - return -(int) result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT - if (length < 3) { - return -3; - } - memcpy(buf, "\xe2\x96\x85", 3); - return 3; - } else if (llama_is_byte_token(model->vocab, token)) { - if (length < 1) { - return -1; - } - buf[0] = llama_token_to_byte(model->vocab, token); - return 1; + break; } - break; - } - case LLAMA_VOCAB_TYPE_BPE: { - // NOTE: we accept all unsupported token types, - // suppressing them like CONTROL tokens. - if (llama_is_normal_token(model->vocab, token)) { - std::string result = model->vocab.id_to_token[token].text; - result = llama_decode_text(result); - if (length < (int) result.length()) { - return -(int) result.length(); + case LLAMA_VOCAB_TYPE_BPE: { + // NOTE: we accept all unsupported token types, + // suppressing them like CONTROL tokens. + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + result = llama_decode_text(result); + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if ( + (llama_is_user_defined_token(model->vocab, token)) || + (llama_is_control_token (model->vocab, token) && special)) { + std::string result = model->vocab.id_to_token[token].text; + if (length < (int) result.length()) { + return -(int) result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); } - memcpy(buf, result.c_str(), result.length()); - return result.length(); - } else if ( - (llama_is_user_defined_token(model->vocab, token)) || - (llama_is_control_token (model->vocab, token) && special)) { - std::string result = model->vocab.id_to_token[token].text; - if (length < (int) result.length()) { - return -(int) result.length(); - } - memcpy(buf, result.c_str(), result.length()); - return result.length(); + break; } - break; - } - default: - GGML_ASSERT(false); + default: + GGML_ASSERT(false); } } return 0; @@ -17878,6 +18318,15 @@ static int32_t llama_chat_apply_template_internal( } } // llama2 templates seem to not care about "add_generation_prompt" + } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos)) { + // Phi 3 + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>\n" << message->content << "<|end|>\n"; + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else if (tmpl == "zephyr" || tmpl.find("<|user|>") != std::string::npos) { // zephyr template for (auto message : chat) { @@ -18010,15 +18459,6 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "phi3" || (tmpl.find("<|assistant|>") != std::string::npos && tmpl.find("<|end|>") != std::string::npos )) { - // Phi 3 - for (auto message : chat) { - std::string role(message->role); - ss << "<|" << role << "|>\n" << trim(message->content) << "<|end|>\n"; - } - if (add_ass) { - ss << "<|assistant|>\n"; - } } else { // template not supported return -1; @@ -18140,8 +18580,10 @@ const char * llama_print_system_info(void) { s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | "; s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | "; + s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | "; s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "SVE = " + std::to_string(ggml_cpu_has_sve()) + " | "; s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; @@ -18200,6 +18642,8 @@ void llama_log_set(ggml_log_callback log_callback, void * user_data) { g_state.log_callback_user_data = user_data; #ifdef GGML_USE_METAL ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); +#elif defined(GGML_USE_CUDA) + ggml_backend_cuda_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); #endif } diff --git a/llama/llama.h b/llama/llama.h index 93d5f361..f91c11ed 100644 --- a/llama/llama.h +++ b/llama/llama.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -107,9 +107,11 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, LLAMA_VOCAB_PRE_TYPE_REFACT = 8, LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 10, - LLAMA_VOCAB_PRE_TYPE_OLMO = 11, - LLAMA_VOCAB_PRE_TYPE_DBRX = 12, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, }; // note: these values should be synchronized with ggml_rope @@ -121,7 +123,7 @@ extern "C" { LLAMA_ROPE_TYPE_GLM = 4, }; - enum llama_token_type { + enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file LLAMA_TOKEN_TYPE_UNDEFINED = 0, LLAMA_TOKEN_TYPE_NORMAL = 1, LLAMA_TOKEN_TYPE_UNKNOWN = 2, @@ -131,6 +133,20 @@ extern "C" { LLAMA_TOKEN_TYPE_BYTE = 6, }; + enum llama_token_attr { + LLAMA_TOKEN_ATTR_UNDEFINED = 0, + LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 0, + LLAMA_TOKEN_ATTR_UNUSED = 1 << 1, + LLAMA_TOKEN_ATTR_NORMAL = 1 << 2, + LLAMA_TOKEN_ATTR_CONTROL = 1 << 3, // SPECIAL? + LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4, + LLAMA_TOKEN_ATTR_BYTE = 1 << 5, + LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 6, + LLAMA_TOKEN_ATTR_LSTRIP = 1 << 7, + LLAMA_TOKEN_ATTR_RSTRIP = 1 << 8, + LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 9, + }; + // model file types enum llama_ftype { LLAMA_FTYPE_ALL_F32 = 0, @@ -289,6 +305,8 @@ extern "C" { bool check_tensors; // validate model tensor data }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations + // https://github.com/ggerganov/llama.cpp/pull/7544 struct llama_context_params { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model @@ -315,14 +333,14 @@ extern "C" { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; - enum ggml_type type_k; // data type for K cache - enum ggml_type type_v; // data type for V cache + enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] + enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] // Keep the booleans together to avoid misalignment during copy-by-value. bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // whether to use flash attention + bool flash_attn; // whether to use flash attention [EXPERIMENTAL] // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -373,6 +391,9 @@ extern "C" { // modifies a preceding LLAMA_GRETYPE_CHAR or // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, }; typedef struct llama_grammar_element { @@ -446,8 +467,8 @@ extern "C" { LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); @@ -784,6 +805,12 @@ extern "C" { // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + // Get the number of threads used for generation of a single token. + LLAMA_API uint32_t llama_n_threads(struct llama_context * ctx); + + // Get the number of threads used for prompt and batch processing (multiple token). + LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx); + // Set whether to use causal attention or not // If set to true, the model will only attend to the past tokens LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); @@ -837,11 +864,14 @@ extern "C" { LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); - LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); + LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + // Identify if Token Id is a control token or a render-able token + LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); + // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence @@ -1055,49 +1085,9 @@ extern "C" { llama_token token); // - // Beam search + // Model split // - struct llama_beam_view { - const llama_token * tokens; - - size_t n_tokens; - float p; // Cumulative beam probability (renormalized relative to all beams) - bool eob; // Callback should set this to true when a beam is at end-of-beam. - }; - - // Passed to beam_search_callback function. - // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams - // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. - // These pointers are valid only during the synchronous callback, so should not be saved. - struct llama_beams_state { - struct llama_beam_view * beam_views; - - size_t n_beams; // Number of elements in beam_views[]. - size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. - bool last_call; // True iff this is the last callback invocation. - }; - - // Type of pointer to the beam_search_callback function. - // void* callback_data is any custom data passed to llama_beam_search, that is subsequently - // passed back to beam_search_callback. This avoids having to use global variables in the callback. - typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); - - /// @details Deterministically returns entire sentence constructed by a beam search. - /// @param ctx Pointer to the llama_context. - /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state. - /// @param callback_data A pointer that is simply passed back to callback. - /// @param n_beams Number of beams to use. - /// @param n_past Number of tokens already evaluated. - /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. - LLAMA_API void llama_beam_search( - struct llama_context * ctx, - llama_beam_search_callback_fn_t callback, - void * callback_data, - size_t n_beams, - int32_t n_past, - int32_t n_predict); - /// @details Build a split GGUF final path for this chunk. /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. diff --git a/llama/llava.cpp b/llama/llava.cpp index ee2e3ee4..16f27eaa 100644 --- a/llama/llava.cpp +++ b/llama/llava.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/llava.h b/llama/llava.h index b2ec0a1a..e54b37cb 100644 --- a/llama/llava.h +++ b/llama/llava.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/log.h b/llama/log.h index 1f4cde51..9223b9e3 100644 --- a/llama/log.h +++ b/llama/log.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/patches/02-default-pretokenizer.diff b/llama/patches/02-llamacpp.diff similarity index 58% rename from llama/patches/02-default-pretokenizer.diff rename to llama/patches/02-llamacpp.diff index 27c8aabc..7b5f9b70 100644 --- a/llama/patches/02-default-pretokenizer.diff +++ b/llama/patches/02-llamacpp.diff @@ -1,8 +1,8 @@ -diff --git a/llama.cpp b/llama.cpp -index 40d2ec2c..74f3ee9c 100644 ---- a/llama.cpp -+++ b/llama.cpp -@@ -4642,16 +4642,7 @@ static void llm_load_vocab( +diff --git a/llama/llama.cpp b/llama/llama.cpp +index 8b675ea9..bcc6ae75 100644 +--- a/llama/llama.cpp ++++ b/llama/llama.cpp +@@ -4645,16 +4645,7 @@ static void llm_load_vocab( // for now, only BPE models have pre-tokenizers if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { @@ -15,12 +15,12 @@ index 40d2ec2c..74f3ee9c 100644 - LLAMA_LOG_WARN("%s: ************************************ \n", __func__); - LLAMA_LOG_WARN("%s: \n", __func__); - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; -- } else if ( -+ if ( - tokenizer_pre == "default") { +- } else if (tokenizer_pre == "default") { ++ if (tokenizer_pre == "default") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( -@@ -4703,7 +4694,8 @@ static void llm_load_vocab( + tokenizer_pre == "llama3" || +@@ -4706,7 +4697,8 @@ static void llm_load_vocab( tokenizer_pre == "smaug-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; } else { @@ -30,3 +30,12 @@ index 40d2ec2c..74f3ee9c 100644 } } else { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; +@@ -7009,7 +7001,7 @@ static struct ggml_tensor * llm_build_kqv( + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + cb(kq, "kq", il); + +- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { ++ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { + // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs + // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); diff --git a/llama/patches/03-metal.diff b/llama/patches/03-metal.diff index f8fa7db7..f3edde3e 100644 --- a/llama/patches/03-metal.diff +++ b/llama/patches/03-metal.diff @@ -1,7 +1,7 @@ -diff --git a/ggml-metal.m b/ggml-metal.m +diff --git a/llama/ggml-metal-darwin_arm64.m b/llama/ggml-metal-darwin_arm64.m index 0207b787..b5e9884b 100644 ---- a/ggml-metal.m -+++ b/ggml-metal.m +--- a/llama/ggml-metal-darwin_arm64.m ++++ b/llama/ggml-metal-darwin_arm64.m @@ -1396,27 +1396,23 @@ static enum ggml_status ggml_metal_graph_compute( // to the matrix-vector kernel int ne11_mm_min = 1; diff --git a/llama/patches/04-qwen2.diff b/llama/patches/04-qwen2.diff deleted file mode 100644 index d7b0c155..00000000 --- a/llama/patches/04-qwen2.diff +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/llama.cpp b/llama.cpp -index 40d2ec2c..f34eb79a 100644 ---- a/llama.cpp -+++ b/llama.cpp -@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv( - struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); - cb(kq, "kq", il); - -- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) { -+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) { - // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs - // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); diff --git a/llama/sampling.cpp b/llama/sampling.cpp index 062a47cf..148d76e5 100644 --- a/llama/sampling.cpp +++ b/llama/sampling.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -151,7 +151,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string result = "CFG -> Penalties "; if (params.mirostat == 0) { for (auto sampler_type : params.samplers_sequence) { - const auto sampler_type_name = sampler_type_to_name_string(sampler_type); + const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); if (!sampler_type_name.empty()) { result += "-> " + sampler_type_name + " "; } @@ -163,6 +163,87 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { return result; } +std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { + switch (sampler_type) { + case llama_sampler_type::TOP_K: return "top_k"; + case llama_sampler_type::TFS_Z: return "tfs_z"; + case llama_sampler_type::TYPICAL_P: return "typical_p"; + case llama_sampler_type::TOP_P: return "top_p"; + case llama_sampler_type::MIN_P: return "min_p"; + case llama_sampler_type::TEMPERATURE: return "temperature"; + default : return ""; + } +} + +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + {"top_k", llama_sampler_type::TOP_K}, + {"top_p", llama_sampler_type::TOP_P}, + {"typical_p", llama_sampler_type::TYPICAL_P}, + {"min_p", llama_sampler_type::MIN_P}, + {"tfs_z", llama_sampler_type::TFS_Z}, + {"temperature", llama_sampler_type::TEMPERATURE} + }; + + // since samplers names are written multiple ways + // make it ready for both system names and input names + std::unordered_map sampler_alt_name_map { + {"top-k", llama_sampler_type::TOP_K}, + {"top-p", llama_sampler_type::TOP_P}, + {"nucleus", llama_sampler_type::TOP_P}, + {"typical-p", llama_sampler_type::TYPICAL_P}, + {"typical", llama_sampler_type::TYPICAL_P}, + {"min-p", llama_sampler_type::MIN_P}, + {"tfs-z", llama_sampler_type::TFS_Z}, + {"tfs", llama_sampler_type::TFS_Z}, + {"temp", llama_sampler_type::TEMPERATURE} + }; + + std::vector sampler_types; + sampler_types.reserve(names.size()); + for (const auto & name : names) + { + auto sampler_item = sampler_canonical_name_map.find(name); + if (sampler_item != sampler_canonical_name_map.end()) + { + sampler_types.push_back(sampler_item->second); + } + else + { + if (allow_alt_names) + { + sampler_item = sampler_alt_name_map.find(name); + if (sampler_item != sampler_alt_name_map.end()) + { + sampler_types.push_back(sampler_item->second); + } + } + } + } + return sampler_types; +} + +std::vector llama_sampling_types_from_chars(const std::string & names_string) { + std::unordered_map sampler_name_map { + {'k', llama_sampler_type::TOP_K}, + {'p', llama_sampler_type::TOP_P}, + {'y', llama_sampler_type::TYPICAL_P}, + {'m', llama_sampler_type::MIN_P}, + {'f', llama_sampler_type::TFS_Z}, + {'t', llama_sampler_type::TEMPERATURE} + }; + + std::vector sampler_types; + sampler_types.reserve(names_string.size()); + for (const auto & c : names_string) { + const auto sampler_item = sampler_name_map.find(c); + if (sampler_item != sampler_name_map.end()) { + sampler_types.push_back(sampler_item->second); + } + } + return sampler_types; +} + // no reasons to expose this function in header static void sampler_queue( struct llama_context * ctx_main, @@ -205,7 +286,7 @@ static llama_token llama_sampling_sample_impl( struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx, - bool is_resampling) { // Add a parameter to indicate if we are resampling + bool is_resampling) { const llama_sampling_params & params = ctx_sampling->params; const float temp = params.temp; @@ -214,8 +295,8 @@ static llama_token llama_sampling_sample_impl( const float mirostat_eta = params.mirostat_eta; std::vector original_logits; - auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits); - if (!is_resampling) { + auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); + if (ctx_sampling->grammar != NULL && !is_resampling) { GGML_ASSERT(!original_logits.empty()); } llama_token id = 0; @@ -278,7 +359,7 @@ static llama_token llama_sampling_sample_impl( // Restore logits from the copy std::copy(original_logits.begin(), original_logits.end(), logits); - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true); } } @@ -311,7 +392,8 @@ static llama_token_data_array llama_sampling_prepare_impl( // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); - if (apply_grammar && original_logits != NULL) { + if (ctx_sampling->grammar != NULL && !apply_grammar) { + GGML_ASSERT(original_logits != NULL); // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))}; } @@ -368,7 +450,7 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, const int idx) { // Call the implementation function with is_resampling set to false by default - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false); } llama_token_data_array llama_sampling_prepare( diff --git a/llama/sampling.h b/llama/sampling.h index 5f7b05e3..e1e7041d 100644 --- a/llama/sampling.h +++ b/llama/sampling.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * @@ -142,6 +142,11 @@ std::string llama_sampling_print(const llama_sampling_params & params); // Print sampling order into a string std::string llama_sampling_order_print(const llama_sampling_params & params); +std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); + +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & names_string); + // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index 0ad7f941..0aadc8e5 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -1,3 +1,263 @@ +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + // TODO: this is a temporary wrapper to allow calling C++ code from CGo #include "sampling.h" #include "sampling_ext.h" diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h index ac791ddf..37955719 100644 --- a/llama/sampling_ext.h +++ b/llama/sampling_ext.h @@ -1,3 +1,263 @@ +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +/** + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 + * + * MIT License + * + * Copyright (c) 2023-2024 The ggml authors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + // TODO: this is a temporary wrapper to allow calling C++ code from CGo #ifndef LLAMA_SAMPLING_EXT_H #define LLAMA_SAMPLING_EXT_H diff --git a/llama/stb_image.h b/llama/stb_image.h index e116ad30..fa13e569 100644 --- a/llama/stb_image.h +++ b/llama/stb_image.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/unicode-data.cpp b/llama/unicode-data.cpp index 070ed987..e5699800 100644 --- a/llama/unicode-data.cpp +++ b/llama/unicode-data.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/unicode-data.h b/llama/unicode-data.h index 63a8f727..75e1fd9d 100644 --- a/llama/unicode-data.h +++ b/llama/unicode-data.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/unicode.cpp b/llama/unicode.cpp index 2083d20e..14564785 100644 --- a/llama/unicode.cpp +++ b/llama/unicode.cpp @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/llama/unicode.h b/llama/unicode.h index 6074f0d6..033818c1 100644 --- a/llama/unicode.h +++ b/llama/unicode.h @@ -1,5 +1,5 @@ /** - * llama.cpp - git 059031b8c40e1f4ba60586842c5b1ed3ddf61842 + * llama.cpp - git d5c938cd7716b9a2ace49a43a469dfbffcff4d28 * * MIT License * diff --git a/scripts/sync_llama.sh b/scripts/sync_llama.sh index aaff057c..a6cfb12a 100755 --- a/scripts/sync_llama.sh +++ b/scripts/sync_llama.sh @@ -1,5 +1,7 @@ #!/bin/bash +set -e + # Set the source directory src_dir=$1 @@ -8,7 +10,7 @@ if [ -z "$src_dir" ]; then exit 1 fi -# Set the destination directory (current directory) +# Set the destination directory dst_dir=./llama # llama.cpp @@ -72,7 +74,7 @@ char const *LLAMA_BUILD_TARGET = ""; EOF # apply patches -for patch in $dst_dir/patches/*.patch; do +for patch in $dst_dir/patches/*.diff; do git apply "$patch" done @@ -112,6 +114,6 @@ echo "_ggml_metallib_start:" >> $TEMP_ASSEMBLY echo ".incbin \"temp.metal\"" >> $TEMP_ASSEMBLY echo ".globl _ggml_metallib_end" >> $TEMP_ASSEMBLY echo "_ggml_metallib_end:" >> $TEMP_ASSEMBLY -as -mmacosx-version-min=11.3 $TEMP_ASSEMBLY -o ggml-metal.o +as -mmacosx-version-min=11.3 $TEMP_ASSEMBLY -o $dst_dir/ggml-metal.o rm -f $TEMP_ASSEMBLY rm -rf temp.metal