vulkan: Add fusion support for RMS_NORM+MUL (llama/14366)

* vulkan: Add fusion support for RMS_NORM+MUL
- Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
- Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
- Add detection logic and basic fusion logic in ggml-vulkan.
- Add some testing support for fusion. Rather than computing one node at a time, allow
  for computing the whole graph and just testing one node's results. Add rms_norm_mul tests
  and enable a llama test.
* extract some common fusion logic
* fix -Winconsistent-missing-override
* move ggml_can_fuse to a common function
* build fix
* C and C++ versions of can_fuse
* move use count to the graph to avoid data races and double increments when used in multiple threads
* use hash table lookup to find node index
* change use_counts to be indexed by hash table slot
* minimize hash lookups
style fixes
* last node doesn't need single use.
fix type.
handle mul operands being swapped.
* remove redundant parameter
---------
Co-authored-by: slaren <[email protected]>
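
For context, the pattern targeted by this change is an RMS_NORM node whose only consumer is a MUL with the layer's norm weights. Below is a minimal sketch of how such a pair is typically built with the public ggml API; the helper name and its parameters are illustrative and not part of this commit.

    #include "ggml.h"

    // Hypothetical helper (not from the commit): builds the fusable pattern
    // rms_norm(x) -> mul(..., w), which llama-style models emit once per layer.
    static struct ggml_tensor * build_rms_norm_mul(struct ggml_context * ctx,
                                                   struct ggml_tensor  * x,
                                                   struct ggml_tensor  * w,
                                                   float                 eps) {
        struct ggml_tensor * normed = ggml_rms_norm(ctx, x, eps); // GGML_OP_RMS_NORM
        // The MUL below is the only consumer of `normed`, so after
        // ggml_build_forward_expand() its use count is 1 and the pair is fusable.
        return ggml_mul(ctx, normed, w);                          // GGML_OP_MUL
    }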
ggml-backend.h:

@@ -339,7 +339,7 @@ extern "C" {
     typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);

     // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);

     // Tensor initialization
     GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);

ggml-backend.cpp:

@@ -817,8 +817,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
-                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
+                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {

@@ -1826,7 +1827,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
     ggml_free(copy.ctx_unallocated);
 }

-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
     struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
     if (copy.buffer == NULL) {
         return false;

@@ -1837,28 +1838,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
     assert(g1->n_nodes == g2->n_nodes);

-    for (int i = 0; i < g1->n_nodes; i++) {
-        struct ggml_tensor * t1 = g1->nodes[i];
-        struct ggml_tensor * t2 = g2->nodes[i];
-
-        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
-
-        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
-        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
-
-        ggml_backend_graph_compute(backend1, &g1v);
-        ggml_backend_graph_compute(backend2, &g2v);
-
-        if (ggml_is_view_op(t1->op)) {
-            continue;
-        }
-
-        // compare results, calculate rms etc
-        if (!callback(i, t1, t2, user_data)) {
-            break;
+    if (test_node != nullptr) {
+        // Compute the whole graph and only test the output for a specific tensor
+        ggml_backend_graph_compute(backend1, g1);
+        ggml_backend_graph_compute(backend2, g2);
+
+        int test_node_idx = -1;
+        for (int i = 0; i < g1->n_nodes; i++) {
+            struct ggml_tensor * t1 = g1->nodes[i];
+            if (t1 == test_node) {
+                test_node_idx = i;
+                break;
+            }
+        }
+        GGML_ASSERT(test_node_idx != -1);
+
+        callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+    } else {
+        for (int i = 0; i < g1->n_nodes; i++) {
+            struct ggml_tensor * t1 = g1->nodes[i];
+            struct ggml_tensor * t2 = g2->nodes[i];
+
+            assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+
+            struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+            struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+
+            ggml_backend_graph_compute(backend1, &g1v);
+            ggml_backend_graph_compute(backend2, &g2v);
+
+            if (ggml_is_view_op(t1->op)) {
+                continue;
+            }
+
+            // compare results, calculate rms etc
+            if (!callback(i, t1, t2, user_data)) {
+                break;
+            }
         }
     }
     ggml_backend_graph_copy_free(copy);

     return true;

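The new `test_node` parameter is what lets the fusion tests compute the whole graph on both backends and compare only the fused node's output. A minimal usage sketch, assuming `backend_vk`, `backend_cpu`, `graph`, and `fused_node` already exist in the test harness (the callback body is illustrative; only the signatures come from the header above):

    // Illustrative callback matching ggml_backend_eval_callback.
    static bool check_node(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
        (void) node_index; (void) user_data;
        // Read both tensors back with ggml_backend_tensor_get() and compute an error metric here.
        GGML_ASSERT(ggml_nelements(t1) == ggml_nelements(t2));
        return true; // in per-node mode, returning false stops the comparison early
    }

    // Whole-graph mode: both backends run `graph` end to end and only the output of
    // `fused_node` (e.g. the MUL produced by RMS_NORM+MUL fusion) is passed to the callback.
    bool ok = ggml_backend_compare_graph_backend(backend_vk, backend_cpu, graph,
                                                 check_node, /*user_data=*/nullptr,
                                                 /*test_node=*/fused_node);

    // Passing test_node == nullptr keeps the previous behavior: each node is computed
    // and compared one at a time.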
ggml-impl.h:

@@ -301,6 +301,7 @@ struct ggml_cgraph {
     struct ggml_tensor ** grads;     // the outputs of these tensors are the gradients of the nodes
     struct ggml_tensor ** grad_accs; // accumulators for node gradients
     struct ggml_tensor ** leafs;     // tensors with constant data
+    int32_t             * use_counts;// number of uses of each tensor, indexed by hash table slot

     struct ggml_hash_set visited_hash_set;

@@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)

+// return true if the node's results are only used by N other nodes
+// and can be fused into their calculations.
+static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    // check the use count against how many we're replacing
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+        return false;
+    }
+
+    // if node is a view, some other node might be using the intermediate result
+    // via the view source.
+    if (node->view_src) {
+        return false;
+    }
+
+    // If the user requested output for the node, can't fuse
+    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+        return false;
+    }
+
+    return true;
+}
+
+// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
+// and are fusable. Nodes are considered fusable according to this function if:
+// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
+// - all nodes except the last are a src of the following node.
+// - all nodes are the same shape.
+// TODO: Consider allowing GGML_OP_NONE nodes in between
+static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    for (int i = 0; i < num_ops; ++i) {
+        struct ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node->op != ops[i]) {
+            return false;
+        }
+        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+            return false;
+        }
+        if (i > 0) {
+            struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            if (node->src[0] != prev && node->src[1] != prev) {
+                return false;
+            }
+            if (!ggml_are_same_shape(node, prev)) {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
 #ifdef __cplusplus
 }
 #endif

 #ifdef __cplusplus
+#include <initializer_list>
 #include <vector>

+// nicer C++ syntax for ggml_can_fuse
+inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);

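The C++ overload is what backends use to detect a fusable pair. A minimal sketch of the caller side, assuming a `cgraph` in scope; the loop skeleton is illustrative, while the `ggml_can_fuse` call itself matches the usage the Vulkan backend adds below:

    // Illustrative backend-side detection loop, mirroring the ggml-vulkan changes below.
    for (int i = 0; i < cgraph->n_nodes; i++) {
        int fused = 0;
        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
            // nodes[i] is an RMS_NORM whose only user is nodes[i + 1] (a MUL of the
            // same shape), so both can be executed as one fused kernel.
            fused = 1;
        }
        // ... dispatch nodes[i] (optionally also consuming nodes[i + 1]) ...
        i += fused; // skip the node that was folded into the fused dispatch
    }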
ggml-vulkan.cpp:

@@ -425,6 +425,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;

@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {

     vk_command_pool compute_cmd_pool;
     vk_command_pool transfer_cmd_pool;
+
+    // number of additional consecutive nodes that are being fused with the
+    // node currently being processed
+    uint32_t num_additional_fused_ops {};
 };

 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT

@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -6430,7 +6436,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_rms_norm_f32;
+            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:

@@ -7530,18 +7536,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }

-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);

-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f,
-        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        op_params[0], 0.0f, 0,
     }, dryrun);
 }

@@ -8736,7 +8743,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t

 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+    ggml_tensor * node = cgraph->nodes[node_idx];
     if (ggml_is_empty(node) || !node->buffer) {
         return false;
     }

@@ -8974,8 +8982,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

         break;
     case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
+        if (ctx->num_additional_fused_ops > 0) {
+            // fused rms_norm + mul
+            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+        } else {
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+        }
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);

@@ -9710,10 +9724,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg

     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+        ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);

@@ -9775,14 +9794,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }

+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i == last_node) ||
+                      (i + ctx->num_additional_fused_ops == last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);

-        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);

         if (vk_perf_logger_enabled) {
             if (ctx->compute_ctx.expired()) {

@@ -9792,7 +9815,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             } else {
                 compute_ctx = ctx->compute_ctx.lock();
             }
-            compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
+            // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
+            for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
+                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
+            }
         }

         if (enqueued) {

@@ -9814,6 +9840,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             }
             submit_count++;
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }

     if (vk_perf_logger_enabled) {

rms_norm.comp:

@@ -1,11 +1,13 @@
 #version 450

-#include "generic_unary_head.comp"
+#include "generic_binary_head.comp"
 #include "types.comp"

 #extension GL_EXT_control_flow_attributes : enable
 #define BLOCK_SIZE 512

+layout (constant_id = 1) const bool do_multiply = false;
+
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

 shared FLOAT_TYPE sum[BLOCK_SIZE];

@@ -25,6 +27,7 @@ void main() {
     const uint stride_sample = p.nb03;

     uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
     uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp

@@ -46,7 +49,13 @@ void main() {
     const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

-    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
-        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+    if (do_multiply) {
+        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+        }
+    } else {
+        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+        }
     }
 }

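Because `do_multiply` is declared with `constant_id = 1`, the fused and non-fused variants come from the same SPIR-V module; the host simply supplies a different specialization value when creating the two pipelines (the `{0, 0}` vs `{0, 1}` arguments to ggml_vk_create_pipeline above). A generic Vulkan sketch of binding such a boolean specialization constant, independent of ggml's own pipeline helper; `shader_module` is assumed to hold the compiled rms_norm.comp:

    #include <vulkan/vulkan.h>

    // Generic illustration (not ggml's helper): specialize `do_multiply` at pipeline
    // creation time so one SPIR-V module yields both the plain and the fused pipeline.
    VkShaderModule shader_module = VK_NULL_HANDLE; // compiled rms_norm.comp SPIR-V (assumed)
    VkBool32 do_multiply = VK_TRUE;                // VK_FALSE for the plain rms_norm_f32 pipeline

    VkSpecializationMapEntry entry = {};
    entry.constantID = 1;                          // matches layout (constant_id = 1) in the shader
    entry.offset     = 0;
    entry.size       = sizeof(do_multiply);

    VkSpecializationInfo spec_info = {};
    spec_info.mapEntryCount = 1;
    spec_info.pMapEntries   = &entry;
    spec_info.dataSize      = sizeof(do_multiply);
    spec_info.pData         = &do_multiply;

    VkPipelineShaderStageCreateInfo stage = {};
    stage.sType               = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    stage.stage               = VK_SHADER_STAGE_COMPUTE_BIT;
    stage.module              = shader_module;
    stage.pName               = "main";
    stage.pSpecializationInfo = &spec_info;        // attach to vkCreateComputePipelines() via this stage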
vulkan-shaders-gen.cpp:

@@ -497,7 +497,7 @@ void process_shaders() {
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

ggml.c:

@@ -5850,19 +5850,32 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }

-static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     // check if already visited
-    if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
-        return;
+    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+        // This is the first time we see this node in the current graph.
+        cgraph->visited_hash_set.keys[node_hash_pos] = node;
+        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+        cgraph->use_counts[node_hash_pos] = 0;
+    } else {
+        // already visited
+        return node_hash_pos;
     }

     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
-            /* unknown order, just fall back to using i*/ i;
-        if (node->src[k]) {
-            ggml_visit_parents(cgraph, node->src[k]);
+            /* unknown order, just fall back to using i */ i;
+
+        struct ggml_tensor * src = node->src[k];
+        if (src) {
+            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+
+            // Update the use count for this operand.
+            cgraph->use_counts[src_hash_pos]++;
         }
     }

@@ -5886,6 +5899,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->n_nodes++;
     }
+
+    return node_hash_pos;
 }

 static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {

@@ -6023,6 +6038,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
     incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
     if (grads) {
         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads

@@ -6052,11 +6068,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz

     void * p = cgraph + 1;

-    struct ggml_tensor ** nodes_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** leafs_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** hash_keys_ptr  = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
-    struct ggml_tensor ** grad_accs_ptr  = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** nodes_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** leafs_ptr      = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    int32_t *             use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+    struct ggml_tensor ** hash_keys_ptr  = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr  = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;

     ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));

@@ -6071,6 +6088,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.grads        =*/ grads_ptr,
         /*.grad_accs    =*/ grad_accs_ptr,
         /*.leafs        =*/ leafs_ptr,
+        /*.use_counts   =*/ use_counts_ptr,
        /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
     };

@@ -6097,7 +6115,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.grads            =*/ NULL, // gradients would need visited_hash_set
         /*.grad_accs        =*/ NULL,
         /*.leafs            =*/ NULL,
-        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.use_counts       =*/ cgraph0->use_counts,
+        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
         /*.order            =*/ cgraph0->order,
     };

@@ -6124,7 +6143,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
         if (ggml_bitset_get(src->visited_hash_set.used, i)) {
-            ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            dst->use_counts[new_hash_pos] = src->use_counts[i];
         }
     }

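Because use_counts is indexed by hash-table slot rather than by node index, reading a tensor's use count always goes through ggml_hash_find, exactly as the scheduler's debug print does in the ggml-backend.cpp change above. A small illustrative check, assuming a built `graph` and a tensor `t` from it and the internal helpers declared in ggml-impl.h:

    // Illustrative: after ggml_build_forward_expand(), look up how many nodes consume `t`.
    size_t slot = ggml_hash_find(&graph->visited_hash_set, t);
    if (ggml_bitset_get(graph->visited_hash_set.used, slot)) {
        int32_t uses = graph->use_counts[slot];
        // For the RMS_NORM node in an rms_norm+mul pair, `uses` is 1, which is the
        // condition ggml_node_has_n_uses()/ggml_can_fuse() test before fusing.
        (void) uses;
    }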