ggml/src/ggml.c: 2 changes (1 addition, 1 deletion)
@@ -1281,7 +1281,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
#if GGML_USE_IQK_MULMAT
-#if defined __AVX2__
+#if defined HAVE_FANCY_SIMD
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_0_X4,
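With IQK matmuls enabled, this hunk narrows the Q8_2_X4 activation format to AVX512-class builds (HAVE_FANCY_SIMD); plain AVX2 builds now fall back to Q8_0_X4. For orientation, a schematic sketch of how ggml consumes this field when multiplying iq4_nl weights by f32 activations (simplified, not the literal ggml_compute_forward_mul_mat loop):

    // sketch: the activations (B matrix) are quantized to vec_dot_type first
    const ggml_type_traits_t * tr  = &type_traits[GGML_TYPE_IQ4_NL];
    const enum ggml_type       vdt = tr->vec_dot_type;   // Q8_2_X4 or Q8_0_X4
    const ggml_from_float_t to_vdt = type_traits[vdt].from_float;
    // for each row of src1: to_vdt(src1_row_f32, wdata, ne10);
    // then tr->vec_dot(...) accumulates against the iq4_nl weights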
ggml/src/iqk/iqk_mul_mat.cpp: 49 changes (47 additions, 2 deletions)
@@ -1750,6 +1750,15 @@ __m256i inline load_iq4nl_values_256() {
return MM256_SET_M128I(val128, val128);
}

__m128i inline load_iq4k_values_128() {
return _mm_loadu_si128((const __m128i *)iq4k_values);
}

__m256i inline load_iq4k_values_256() {
auto val128 = load_iq4k_values_128();
return MM256_SET_M128I(val128, val128);
}
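// note: the first 16 iq4k_values coincide with the iq4nl codebook, which is
// presumably why the non-FANCY_SIMD path below can reuse this LUT for IQ4_NL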

#ifdef HAVE_FANCY_SIMD
//====================================== Zen4 ==================================================

@@ -8519,7 +8528,11 @@ struct Q4_0_1_Dequantizer {

struct IQ4_NL_Dequantizer {
Dequantizer4bit b4;
#ifdef HAVE_FANCY_SIMD
const __m256i values = load_iq4nl_values_256();
#else
const __m256i values = load_iq4k_values_256();
#endif
inline __m256i dequant(const block_iq4_nl * x) const {
return _mm256_shuffle_epi8(values, b4.dequant(x->qs));
}
@@ -8630,11 +8643,19 @@ struct Q4_0_1_Unpacker final : public Q_Unpacker<block_q4_0, ScaleHelperQ_0_1<8>
using Sum4T = Sum4TypeQ82;
inline static int block_size() { return QK4_0; }
};
#ifdef HAVE_FANCY_SIMD
struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0_1<128>, IQ4_NL_Dequantizer> {
IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ82;
inline static int block_size() { return QK4_NL; }
};
#else
struct IQ4_NL_Unpacker final : public Q_Unpacker<block_iq4_nl, ScaleHelperQ_0, IQ4_NL_Dequantizer> {
IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
inline static int block_size() { return QK4_NL; }
};
#endif
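// (The two variants above mirror the vec_dot_type choice in ggml.c:
//  ScaleHelperQ_0_1<128> + Sum4TypeQ82 pair with GGML_TYPE_Q8_2_X4, while
//  ScaleHelperQ_0 + Sum4TypeQ80 pair with GGML_TYPE_Q8_0_X4.)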
struct Q5_0_Unpacker final : public Q_Unpacker<block_q5_0, ScaleHelperQ_0, Q5_0_Dequantizer> {
Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
using Sum4T = Sum4TypeQ80;
@@ -9155,9 +9176,29 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
}
else if constexpr (std::is_same_v<Dequantizer, IQ4_NL_Unpacker>) {
#ifdef HAVE_FANCY_SIMD
m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
m.funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
m.funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
m.funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
m.funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
#else
m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
#endif
}
 else if constexpr (std::is_same_v<Dequantizer, Q8_0_1_Unpacker> || std::is_same_v<Dequantizer, Q4_0_1_Unpacker> ||
-                   std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, IQ4_NL_Unpacker> ||
-                   std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
+                   std::is_same_v<Dequantizer, Q5_0_1_Unpacker> || std::is_same_v<Dequantizer, Q6_0_1_Unpacker>) {
m.funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
m.funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
Expand Down Expand Up @@ -9476,7 +9517,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ4_NL:
assert (ne00 % QK4_NL == 0);
MulMat::set_functions<IQ4_NL_Unpacker>(mm);
#ifdef HAVE_FANCY_SIMD
expected_typeB = GGML_TYPE_Q8_2_X4;
#else
expected_typeB = GGML_TYPE_Q8_0_X4;
#endif
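// expected_typeB must mirror the vec_dot_type advertised for IQ4_NL in ggml.c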
break;
case GGML_TYPE_IQ4_NL_R4:
assert (ne00 % QK4_NL == 0);
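For readers new to this code, the dequantization idiom that load_iq4k_values_128/256 feed (a 16-entry LUT broadcast into both 128-bit lanes, consumed by _mm256_shuffle_epi8) can be sketched self-contained as follows. The helper name dequant_nibbles is hypothetical; only the intrinsics are real:

    #include <immintrin.h>

    // Map 32 packed 4-bit indices (16 bytes) to int8 codebook values.
    // lut256 must hold the same 16-byte table in both 128-bit lanes.
    static inline __m256i dequant_nibbles(__m256i lut256, __m128i packed) {
        const __m256i low4 = _mm256_set1_epi8(0x0f);
        // low nibbles land in the low lane, high nibbles in the high lane
        __m256i idx = _mm256_set_m128i(_mm_srli_epi16(packed, 4), packed);
        idx = _mm256_and_si256(idx, low4);
        return _mm256_shuffle_epi8(lut256, idx); // per-lane 16-byte lookup
    }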
src/llama.cpp: 203 changes (199 additions, 4 deletions)
@@ -229,6 +229,7 @@ enum llm_arch {
LLM_ARCH_JAIS,
LLM_ARCH_GRANITE = 46,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_COHERE2,
LLM_ARCH_UNKNOWN,
};

@@ -279,6 +280,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_COHERE2, "cohere2" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -1456,7 +1458,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},

{
LLM_ARCH_COHERE2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -2539,6 +2555,7 @@ struct llama_hparams {
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_swa != other.n_swa) return true;
if (this->n_swa_pattern != other.n_swa_pattern) return true;
if (this->n_embd_head_k != other.n_embd_head_k) return true;
if (this->n_embd_head_v != other.n_embd_head_v) return true;
if (this->n_expert != other.n_expert) return true;
@@ -5797,6 +5814,17 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_COHERE2:
{
hparams.n_swa_pattern = 4;
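// (pattern of 4: every fourth layer is global, the other three use the
//  sliding window; see build_cohere2 below)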
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

@@ -6406,6 +6434,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot);
LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa);
LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern);
LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k);
LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v);
LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str());
@@ -8397,6 +8426,34 @@ static bool llm_load_tensors(
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
}
} break;
case LLM_ARCH_COHERE2:
{
model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);

// output
model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
// init output from the input tok embed
model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
llama_model_loader::TENSOR_DUPLICATED);

for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);

layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);

layer.wq = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
layer.wk = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);

layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
}
}
break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -9340,7 +9397,7 @@ static struct ggml_tensor * llm_build_kqv(
// For DeepSeek-2, it is perfectly fine with fp16 for PP, but I get gibberish when using fp16 for TG.
// Not sure if it is really a matter of insufficient precision, or I have made a mistake in the fattn-vec-f16 kernel.
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX ||
-        (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8)) {
+        (model.arch == LLM_ARCH_DEEPSEEK2 && q->ne[1] <= 8) || model.arch == LLM_ARCH_COHERE2) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
}
//ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -9364,7 +9421,8 @@ static struct ggml_tensor * llm_build_kqv(

//ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

-if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
+    model.arch == LLM_ARCH_COHERE2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -9423,7 +9481,8 @@ static struct ggml_tensor * llm_build_kqv(
auto k_i = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], this_ne12, k->nb[1], k->nb[2], k->nb[2]*i02);
auto q_i = ggml_view_3d(ctx, q, q->ne[0], q->ne[1], this_ne12, q->nb[1], q->nb[2], q->nb[2]*i12);
auto kq_i = ggml_mul_mat(ctx, k_i, q_i);
-if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
+if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 ||
+    model.arch == LLM_ARCH_COHERE2) {
ggml_mul_mat_set_prec(kq_i, GGML_PREC_F32);
}
if (model.arch == LLM_ARCH_GROK) {
@@ -15013,6 +15072,137 @@ struct llm_build_context {
return gf;
}

struct ggml_cgraph * build_cohere2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
const float f_logit_scale = hparams.f_logit_scale;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head; it is broadcast to all heads)
// cohere2 requires a different mask for the layers that use sliding window attention (SWA)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa();
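// (KQ_mask is the usual causal mask, used by the global-attention layers;
//  KQ_mask_swa additionally masks positions more than n_swa tokens in the
//  past, used by the sliding-window layers)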

// sliding window switch pattern
const int32_t sliding_window_pattern = 4;

for (int il = 0; il < n_layer; ++il) {
// three out of every four layers use sliding window attention (window size 4096) with RoPE;
// the fourth layer uses global attention without positional embeddings
const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
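// e.g. with sliding_window_pattern == 4:
//   il % 4 == 0, 1, 2 -> is_sliding = true  (SWA layer, RoPE applied)
//   il % 4 == 3       -> is_sliding = false (global layer, no RoPE)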
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;

// norm
cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM, cb, il);
cb(cur, "attn_norm", il);
struct ggml_tensor * ffn_inp = cur;

// self-attention
{
// rope freq factors for 128k context
struct ggml_tensor * rope_factors = build_rope_factors(il);

// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

if (is_sliding) {
Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
beta_fast, beta_slow);
cb(Qcur, "Qcur", il);

Kcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
} else {
// For non-sliding layers, just reshape without applying RoPE
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
cb(Qcur, "Qcur", il);

Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
cb(Kcur, "Kcur", il);
}

cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, Kcur, Vcur, Qcur,
KQ_mask_l, n_tokens, kv_head, n_kv, 1.0f / sqrtf(float(n_embd_head)), cb, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}

struct ggml_tensor * attn_out = cur;

// feed-forward network
{
cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
cb, il);
cb(cur, "ffn_out", il);
}

// add together residual + FFN + self-attention
cur = ggml_add(ctx0, cur, inpL);
cur = ggml_add(ctx0, cur, attn_out);
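// (parallel transformer block: attention and FFN both read the same normed
//  input ffn_inp, so the layer output is inpL + attn_out + ffn_out rather
//  than two sequential residual additions)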
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);

// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);

if (f_logit_scale) {
cur = ggml_scale(ctx0, cur, f_logit_scale);
}

cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}

struct ggml_cgraph * build_t5_encoder() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

@@ -15813,6 +16003,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_bitnet_25();
} break;
case LLM_ARCH_COHERE2:
{
result = llm.build_cohere2();
} break;
case LLM_ARCH_T5:
{
if (lctx.is_encoding) {
@@ -19486,6 +19680,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_COHERE2:
return LLAMA_ROPE_TYPE_NORM;

// the pairs of head values are offset by n_rot/2