diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 03bc5ad0ba..487820a196 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -57,7 +57,7 @@ static bool ggml_is_power_of_2(int n) {
 }
 
 // orthonormal Walsh-Hadamard rotation matrix
-static void set_input_hadamard(int n, float * data) {
+static void set_input_hadamard(float * data, int n, int H) {
     assert(ggml_is_power_of_2(n));
 
     data[0*n + 0] = 1.0 / sqrtf(n);
@@ -73,6 +73,21 @@ static void set_input_hadamard(int n, float * data) {
             }
         }
     }
+
+    srand(1242);
+
+    // copy to other heads
+    for (int h = 1; h < H; h++) {
+        //memcpy(data + h*n*n, data + (h-1)*n*n, n*n*sizeof(float));
+
+        for (int i = 0; i < n; i++) {
+            float sgn = rand() % 2 ? 1.0f : -1.0f;
+            for (int j = 0; j < n; j++) {
+                data[h*n*n + j*n + i] = sgn*data[j*n + i];
+                //data[h*n*n + (h-1)*n + j] *= sgn;
+            }
+        }
+    }
 }
 
 static ggml_tensor * ggml_rotate_hadamard(
@@ -82,7 +97,8 @@ static ggml_tensor * ggml_rotate_hadamard(
     const auto n = rot->ne[0];
 
     ggml_tensor * res;
-    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
+    res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/(n), cur->ne[1], cur->ne[2]);
+    //res = ggml_reshape_3d(ctx, cur, n, ggml_nelements(cur)/(n*cur->ne[1]), cur->ne[1]);
 
     res = ggml_mul_mat(ctx, rot, res);
     res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
@@ -472,7 +488,7 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) self_rotk->data;
 
-        set_input_hadamard(self_rotk->ne[0], data);
+        set_input_hadamard(data, self_rotk->ne[0], self_rotk->ne[2]);
     }
 
     if (self_rotv) {
@@ -480,7 +496,7 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) self_rotv->data;
 
-        set_input_hadamard(self_rotv->ne[0], data);
+        set_input_hadamard(data, self_rotv->ne[0], self_rotv->ne[2]);
     }
 }
 
@@ -535,7 +551,7 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) self_rotk->data;
 
-        set_input_hadamard(self_rotk->ne[0], data);
+        set_input_hadamard(data, self_rotk->ne[0], self_rotk->ne[2]);
     }
 
     if (self_rotv) {
@@ -543,7 +559,7 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) self_rotv->data;
 
-        set_input_hadamard(self_rotv->ne[0], data);
+        set_input_hadamard(data, self_rotv->ne[0], self_rotv->ne[2]);
     }
 }
 
@@ -606,7 +622,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) inp_attn->self_rotk->data;
 
-        set_input_hadamard(inp_attn->self_rotk->ne[0], data);
+        set_input_hadamard(data, inp_attn->self_rotk->ne[0], inp_attn->self_rotk->ne[2]);
     }
 
     if (inp_attn->self_rotv) {
@@ -614,7 +630,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) inp_attn->self_rotv->data;
 
-        set_input_hadamard(inp_attn->self_rotv->ne[0], data);
+        set_input_hadamard(data, inp_attn->self_rotv->ne[0], inp_attn->self_rotv->ne[2]);
     }
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
@@ -720,7 +736,7 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) inp_attn->self_rotk->data;
 
-        set_input_hadamard(inp_attn->self_rotk->ne[0], data);
+        set_input_hadamard(data, inp_attn->self_rotk->ne[0], inp_attn->self_rotk->ne[2]);
     }
 
     if (inp_attn->self_rotv) {
@@ -728,7 +744,7 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
 
         float * data = (float *) inp_attn->self_rotv->data;
 
-        set_input_hadamard(inp_attn->self_rotv->ne[0], data);
+        set_input_hadamard(data, inp_attn->self_rotv->ne[0], inp_attn->self_rotv->ne[2]);
     }
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
@@ -2117,12 +2133,12 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
 
     if (can_rotk) {
         int nrot = 64;
-        do {
-            nrot *= 2;
-        } while (hparams.n_embd_head_k() % nrot == 0);
-        nrot /= 2;
+        //do {
+        //    nrot *= 2;
+        //} while (hparams.n_embd_head_k() % nrot == 0);
+        //nrot /= 2;
 
-        inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nrot, nrot);
+        inp->self_rotk = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, nrot, nrot, hparams.n_head_kv());
         ggml_set_input(inp->self_rotk);
     } else {
         inp->self_rotk = nullptr;
@@ -2493,12 +2509,12 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
 
     if (can_rotk) {
         int nrot = 64;
-        do {
-            nrot *= 2;
-        } while (hparams.n_embd_head_k() % nrot == 0);
-        nrot /= 2;
+        //do {
+        //    nrot *= 2;
+        //} while (hparams.n_embd_head_k() % nrot == 0);
+        //nrot /= 2;
 
-        inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nrot, nrot);
+        inp->self_rotk = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, nrot, nrot, hparams.n_head_kv());
         ggml_set_input(inp->self_rotk);
     } else {
         inp->self_rotk = nullptr;
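For reference, the per-head sign flips are safe because negating an entire row or column of an orthonormal matrix leaves it orthonormal: column norms are unchanged, and pairwise dot products only flip sign where they are already zero. Below is a minimal standalone sketch, not part of the patch, that reproduces the construction and verifies this; it assumes the usual Sylvester doubling for head 0 (that part is elided from the hunk above), and the seed 1242 is copied from the patch:

```cpp
// Standalone check (not part of the patch): build per-head Hadamard rotations
// the way the modified set_input_hadamard() does, then verify that every
// head's matrix R still satisfies R^T R = I.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const int n = 4; // rotation size, must be a power of 2
    const int H = 3; // number of KV heads

    std::vector<float> data(H*n*n, 0.0f);

    // head 0: Sylvester doubling for the orthonormal Walsh-Hadamard matrix
    // (assumed construction - elided from the hunk above)
    data[0] = 1.0f/sqrtf((float) n);
    for (int m = 1; m < n; m *= 2) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < m; j++) {
                const float v = data[j*n + i];
                data[(j + 0)*n + (i + m)] =  v;
                data[(j + m)*n + (i + 0)] =  v;
                data[(j + m)*n + (i + m)] = -v;
            }
        }
    }

    // heads 1..H-1: copy head 0, flipping the sign of one full slice
    // (fixed i) at a time - the loop added in the patch
    srand(1242);
    for (int h = 1; h < H; h++) {
        for (int i = 0; i < n; i++) {
            const float sgn = rand() % 2 ? 1.0f : -1.0f;
            for (int j = 0; j < n; j++) {
                data[h*n*n + j*n + i] = sgn*data[j*n + i];
            }
        }
    }

    // orthonormality check: columns of each head's matrix must be unit
    // length and pairwise orthogonal
    for (int h = 0; h < H; h++) {
        const float * R = data.data() + h*n*n;
        for (int a = 0; a < n; a++) {
            for (int b = 0; b < n; b++) {
                float dot = 0.0f;
                for (int j = 0; j < n; j++) {
                    dot += R[j*n + a]*R[j*n + b];
                }
                assert(fabsf(dot - (a == b ? 1.0f : 0.0f)) < 1e-5f);
            }
        }
    }

    printf("all %d per-head rotations are orthonormal\n", H);
    return 0;
}
```

Note that the commented-out `memcpy` path would make every head identical; the sign flips instead give each head a distinct but still orthonormal rotation, which is presumably the point of extending `self_rotk`/`self_rotv` to three dimensions.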
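As for the `ggml_rotate_hadamard` change: the old `ggml_reshape_2d` flattened everything into n-sized rows, so the single 2D `rot` applied one shared matrix across all heads; the new `ggml_reshape_4d` keeps the head dimension in ne[2], lining it up with the per-head matrices in the 3D `rot`. Here is a hypothetical plain-loop reference for the computation this builds, not code from the patch; it assumes ggml's ne[0]-fastest memory layout, that `ggml_mul_mat` contracts the ne[0] dimension of both operands while broadcasting `rot` over the token dimension, and that `rot`'s ne[2] (n_head_kv in the patch) matches the head count of the tensor being rotated:

```cpp
// Hypothetical reference (not the ggml implementation): cur is
// [n_embd_head, n_head, n_tokens] with ne[0] fastest, rot is
// [n, n, n_head]; every n-sized chunk of the head dimension is
// rotated by its own head's matrix.
#include <cmath>
#include <cstdio>
#include <vector>

static void rotate_hadamard_ref(
        const std::vector<float> & rot, // n*n*n_head
        std::vector<float> & cur,       // n_embd_head*n_head*n_tokens
        int n, int n_embd_head, int n_head, int n_tokens) {
    const int n_chunks = n_embd_head/n; // reshape_4d: [n, n_embd_head/n, n_head, n_tokens]

    std::vector<float> tmp(n);

    for (int t = 0; t < n_tokens; t++) {
        for (int h = 0; h < n_head; h++) {
            const float * R = rot.data() + h*n*n; // this head's rotation
            float * x = cur.data() + (t*n_head + h)*n_embd_head;

            for (int c = 0; c < n_chunks; c++) {
                // tmp[r] = dot(row r of R, chunk) - mul_mat contracts ne[0]
                for (int r = 0; r < n; r++) {
                    float acc = 0.0f;
                    for (int k = 0; k < n; k++) {
                        acc += R[r*n + k]*x[c*n + k];
                    }
                    tmp[r] = acc;
                }
                for (int r = 0; r < n; r++) {
                    x[c*n + r] = tmp[r];
                }
            }
        }
    }
}

int main() {
    const float s = 1.0f/sqrtf(2.0f);

    // two heads: the 2x2 Hadamard and a column-sign-flipped copy of it
    std::vector<float> rot = { s, s, s, -s,   -s, s, -s, -s };
    std::vector<float> cur = { 1, 0, 0, 1,    1, 2, 3, 4 };

    rotate_hadamard_ref(rot, cur, /*n=*/2, /*n_embd_head=*/4, /*n_head=*/2, /*n_tokens=*/1);

    for (float v : cur) {
        printf("%g ", v);
    }
    printf("\n");

    return 0;
}
```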