cont : random Hadamard matrices

This commit is contained in:
Georgi Gerganov 2026-03-27 20:11:47 +02:00
parent 7711b3a36a
commit f0fea264b0
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -57,7 +57,7 @@ static bool ggml_is_power_of_2(int n) {
}
// orthonormal Walsh-Hadamard rotation matrix
static void set_input_hadamard(int n, float * data) {
static void set_input_hadamard(float * data, int n, int H) {
assert(ggml_is_power_of_2(n));
data[0*n + 0] = 1.0 / sqrtf(n);
@ -73,6 +73,21 @@ static void set_input_hadamard(int n, float * data) {
}
}
}
srand(1242);
// copy to other heads
for (int h = 1; h < H; h++) {
//memcpy(data + h*n*n, data + (h-1)*n*n, n*n*sizeof(float));
for (int i = 0; i < n; i++) {
float sgn = rand() % 2 ? 1.0f : -1.0f;
for (int j = 0; j < n; j++) {
data[h*n*n + j*n + i] = sgn*data[j*n + i];
//data[h*n*n + (h-1)*n + j] *= sgn;
}
}
}
}
static ggml_tensor * ggml_rotate_hadamard(
@ -82,7 +97,8 @@ static ggml_tensor * ggml_rotate_hadamard(
const auto n = rot->ne[0];
ggml_tensor * res;
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_reshape_4d(ctx, cur, n, cur->ne[0]/(n), cur->ne[1], cur->ne[2]);
//res = ggml_reshape_3d(ctx, cur, n, ggml_nelements(cur)/(n*cur->ne[1]), cur->ne[1]);
res = ggml_mul_mat(ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
@ -472,7 +488,7 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
float * data = (float *) self_rotk->data;
set_input_hadamard(self_rotk->ne[0], data);
set_input_hadamard(data, self_rotk->ne[0], self_rotk->ne[2]);
}
if (self_rotv) {
@ -480,7 +496,7 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
float * data = (float *) self_rotv->data;
set_input_hadamard(self_rotv->ne[0], data);
set_input_hadamard(data, self_rotv->ne[0], self_rotv->ne[2]);
}
}
@ -535,7 +551,7 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
float * data = (float *) self_rotk->data;
set_input_hadamard(self_rotk->ne[0], data);
set_input_hadamard(data, self_rotk->ne[0], self_rotk->ne[2]);
}
if (self_rotv) {
@ -543,7 +559,7 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
float * data = (float *) self_rotv->data;
set_input_hadamard(self_rotv->ne[0], data);
set_input_hadamard(data, self_rotv->ne[0], self_rotv->ne[2]);
}
}
@ -606,7 +622,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
float * data = (float *) inp_attn->self_rotk->data;
set_input_hadamard(inp_attn->self_rotk->ne[0], data);
set_input_hadamard(data, inp_attn->self_rotk->ne[0], inp_attn->self_rotk->ne[2]);
}
if (inp_attn->self_rotv) {
@ -614,7 +630,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
float * data = (float *) inp_attn->self_rotv->data;
set_input_hadamard(inp_attn->self_rotv->ne[0], data);
set_input_hadamard(data, inp_attn->self_rotv->ne[0], inp_attn->self_rotv->ne[2]);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
@ -720,7 +736,7 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
float * data = (float *) inp_attn->self_rotk->data;
set_input_hadamard(inp_attn->self_rotk->ne[0], data);
set_input_hadamard(data, inp_attn->self_rotk->ne[0], inp_attn->self_rotk->ne[2]);
}
if (inp_attn->self_rotv) {
@ -728,7 +744,7 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
float * data = (float *) inp_attn->self_rotv->data;
set_input_hadamard(inp_attn->self_rotv->ne[0], data);
set_input_hadamard(data, inp_attn->self_rotv->ne[0], inp_attn->self_rotv->ne[2]);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
@ -2117,12 +2133,12 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
if (can_rotk) {
int nrot = 64;
do {
nrot *= 2;
} while (hparams.n_embd_head_k() % nrot == 0);
nrot /= 2;
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_k() % nrot == 0);
//nrot /= 2;
inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nrot, nrot);
inp->self_rotk = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, nrot, nrot, hparams.n_head_kv());
ggml_set_input(inp->self_rotk);
} else {
inp->self_rotk = nullptr;
@ -2493,12 +2509,12 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
if (can_rotk) {
int nrot = 64;
do {
nrot *= 2;
} while (hparams.n_embd_head_k() % nrot == 0);
nrot /= 2;
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_k() % nrot == 0);
//nrot /= 2;
inp->self_rotk = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, nrot, nrot);
inp->self_rotk = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, nrot, nrot, hparams.n_head_kv());
ggml_set_input(inp->self_rotk);
} else {
inp->self_rotk = nullptr;