JohannesGaessler committed
Commit dd33ace · 1 Parent(s): 026d20b

ggml: new optimization interface (ggml/988)


* ggml: new optimization interface

remove test2.c, test3.c

store adamw params in tensor

move grads from tensor to graph

* avoid segfault upon API misuse

* add ggml-opt.h to public headers

* remove dependence of ggml-opt.cpp on ggml-cpu.h

ggml/include/ggml-backend.h CHANGED
@@ -86,7 +86,7 @@ extern "C" {
     GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);

-    // "offset" refers to the offset of the tensor data for setting/getting data
+    // "offset" refers to the offset in tensor->data for setting/getting data
     GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
     GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
@@ -242,14 +242,20 @@ extern "C" {
         ggml_backend_sched_reserve(sched, reserve_graph);

         // compute
-        graph = build_graph(sched);
-        ggml_backend_sched_graph_compute(sched, graph);
+        graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation
+        for (int i = 0; i < 10; ++i) {
+            ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically
+        }

         // if there are graph inputs:
-        ggml_backend_sched_reset(sched);
-        ggml_backend_sched_alloc_graph(sched, graph);
-        ggml_backend_tensor_set(input_tensor, ...);
-        ggml_backend_sched_graph_compute(sched, graph);
+        graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called)
+        ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
+        ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it
+        ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors
+        ggml_backend_sched_graph_compute(sched, graph); // execute the graph
+
+        // as an alternative to the above it is also possible to assign the inputs to a dedicated context and
+        // allocate them statically via ggml_backend_alloc_ctx_tensors
     }
     */

@@ -264,7 +270,7 @@ extern "C" {
     //
     typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data);

-    // Initialize a backend scheduler
+    // Initialize a backend scheduler, backends with low index are given priority over backends with high index
     GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel);
     GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);

@@ -289,7 +295,9 @@ extern "C" {
     GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched);

-    // Reset all assignments and allocators - must be called before changing the node backends
+    // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph.
+    // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers.
+    // The correct way to use this API is to discard the deallocated tensors and create new ones.
     GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched);

     // Set a callback to be called for each resulting node during graph compute
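
The last two comment lines added above mention a pattern that is easy to miss: keeping the graph inputs in a dedicated, statically allocated context so they survive ggml_backend_sched_reset(). The following is a minimal, hedged sketch of that pattern; it is not part of the commit, the build_graph helper and all sizes are assumptions for illustration.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"

    // Hypothetical helper (assumption): builds the compute graph for the model and references `input`.
    struct ggml_cgraph * build_graph(ggml_backend_sched_t sched, struct ggml_tensor * input);

    // Sketch: inputs live in their own context and backend buffer, so they are not
    // touched when the scheduler allocations are reset for a new graph.
    static void run_with_static_inputs(ggml_backend_sched_t sched, ggml_backend_t backend, const float * data, int64_t n) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ ggml_tensor_overhead() * 8, // metadata only, tensor data goes into the backend buffer (assumption)
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx_inputs = ggml_init(ip);
        struct ggml_tensor  * input      = ggml_new_tensor_1d(ctx_inputs, GGML_TYPE_F32, n);

        // allocate the input tensors once; ggml_backend_sched_reset() does not free this buffer
        ggml_backend_buffer_t buf_inputs = ggml_backend_alloc_ctx_tensors(ctx_inputs, backend);

        ggml_backend_tensor_set(input, data, 0, ggml_nbytes(input)); // upload the input data once

        struct ggml_cgraph * graph = build_graph(sched, input);
        ggml_backend_sched_reset(sched);              // clear allocations of any previous graph
        ggml_backend_sched_alloc_graph(sched, graph); // allocate only the non-static tensors
        ggml_backend_sched_graph_compute(sched, graph);

        ggml_backend_buffer_free(buf_inputs);
        ggml_free(ctx_inputs);
    }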
ggml/include/ggml-opt.h ADDED
@@ -0,0 +1,216 @@
+// This file contains functionality for training models using GGML.
+// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets.
+// At the bottom of this file especially there are relatively high-level functions that are suitable for use or adaptation in user code.
+//
+// Module maintainer: Johannes Gäßler (@JohannesGaessler, [email protected])
+
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+    struct ggml_opt_dataset;
+    struct ggml_opt_context;
+    struct ggml_opt_result;
+
+    typedef struct ggml_opt_dataset * ggml_opt_dataset_t;
+    typedef struct ggml_opt_context * ggml_opt_context_t;
+    typedef struct ggml_opt_result  * ggml_opt_result_t;
+
+    // ====== Loss ======
+
+    // built-in loss types, i.e. the built-in quantities minimized by the optimizer
+    // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value
+    enum ggml_opt_loss_type {
+        GGML_OPT_LOSS_TYPE_MEAN,
+        GGML_OPT_LOSS_TYPE_SUM,
+        GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
+        GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
+    };
+
+    // ====== Dataset ======
+
+    GGML_API ggml_opt_dataset_t ggml_opt_dataset_init(
+            int64_t ne_datapoint, // number of elements per datapoint
+            int64_t ne_label,     // number of elements per label
+            int64_t ndata,        // total number of datapoints/labels
+            int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied)
+    GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset);
+
+    // get underlying tensors that store the data
+    GGML_API struct ggml_tensor * ggml_opt_dataset_data  (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata]
+    GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [ne_label,     ndata]
+
+    // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative
+    GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata);
+
+    // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch
+    GGML_API void ggml_opt_dataset_get_batch(
+            ggml_opt_dataset_t   dataset,
+            struct ggml_tensor * data_batch,   // shape = [ne_datapoint, ndata_batch]
+            struct ggml_tensor * labels_batch, // shape = [ne_label,     ndata_batch]
+            int64_t              ibatch);
+
+    // ====== Model / Context ======
+
+    enum ggml_opt_build_type {
+        GGML_OPT_BUILD_TYPE_FORWARD,
+        GGML_OPT_BUILD_TYPE_GRAD,
+        GGML_OPT_BUILD_TYPE_OPT,
+    };
+
+    // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
+    struct ggml_opt_optimizer_params {
+        // AdamW optimizer parameters
+        struct {
+            float alpha; // learning rate
+            float beta1;
+            float beta2;
+            float eps;   // epsilon for numerical stability
+            float wd;    // weight decay for AdamW, use 0.0f to disable
+        } adamw;
+    };
+
+    // callback to calculate optimizer parameters prior to a backward pass
+    // userdata can be used to pass arbitrary data
+    typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata);
+
+    // returns the default optimizer params (constant)
+    // userdata is not used
+    GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata);
+
+    // parameters for initializing a new optimization context
+    struct ggml_opt_params {
+        ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs
+
+        struct ggml_context * ctx_compute; // created in user code, holds non-static tensors
+
+        // the forward graph is defined by inputs and outputs
+        // those tensors and all tensors in between are not intended to be reusable between multiple optimization contexts
+        struct ggml_tensor * inputs;
+        struct ggml_tensor * outputs;
+
+        enum ggml_opt_loss_type  loss_type;
+        enum ggml_opt_build_type build_type;
+
+        int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
+
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+    };
+
+    // get parameters for an optimization context with defaults set where possible
+    // parameters for which no sensible defaults exist are supplied as arguments to this function
+    GGML_API ggml_opt_params ggml_opt_default_params(
+            ggml_backend_sched_t      backend_sched,
+            struct ggml_context     * ctx_compute,
+            struct ggml_tensor      * inputs,
+            struct ggml_tensor      * outputs,
+            enum ggml_opt_loss_type   loss_type);
+
+    GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params);
+    GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx);
+
+    // set gradients to zero, initialize loss, and optionally reset the optimizer
+    GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
+
+    // get underlying tensors that store data
+    GGML_API struct ggml_tensor * ggml_opt_inputs(  ggml_opt_context_t opt_ctx); // forward graph input tensor
+    GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor
+    GGML_API struct ggml_tensor * ggml_opt_labels(  ggml_opt_context_t opt_ctx); // labels to compare outputs against
+    GGML_API struct ggml_tensor * ggml_opt_loss(    ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss
+    GGML_API struct ggml_tensor * ggml_opt_pred(    ggml_opt_context_t opt_ctx); // predictions made by outputs
+    GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels
+
+    GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
+
+    // ====== Optimization Result ======
+
+    GGML_API ggml_opt_result_t ggml_opt_result_init();
+    GGML_API void ggml_opt_result_free(ggml_opt_result_t result);
+    GGML_API void ggml_opt_result_reset(ggml_opt_result_t result);
+
+    // get data from result, uncertainties are optional and can be ignored by passing NULL
+    GGML_API void ggml_opt_result_ndata(   ggml_opt_result_t result, int64_t * ndata);                 // writes 1 value, number of datapoints
+    GGML_API void ggml_opt_result_loss(    ggml_opt_result_t result, double * loss,     double * unc); // writes 1 value
+    GGML_API void ggml_opt_result_pred(    ggml_opt_result_t result, int32_t * pred);                  // writes ndata values
+    GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value
+
+    // ====== Computation ======
+
+    // do forward pass, increment result if not NULL
+    GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+
+    // do forward pass, increment result if not NULL, do backward pass
+    GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result);
+
+    // ############################################################################
+    // ## The high-level functions start here. They do not depend on any private ##
+    // ## functions or structs and can be copied to and adapted for user code.   ##
+    // ############################################################################
+
+    // ====== Intended Usage ======
+    //
+    // 1. Select the appropriate loss for your problem.
+    // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them.
+    //    Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster).
+    // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors.
+    //    The first context should contain the model parameters and inputs and be allocated statically in user code.
+    //    The second context should contain all other tensors and will be (re)allocated automatically.
+    //    Due to this automated allocation the data of the second context is not defined when accessed in user code.
+    //    Note that the second dimension of the inputs/outputs is interpreted as the number of datapoints in those tensors.
+    // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
+
+    // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
+    typedef void (*ggml_opt_epoch_callback)(
+            bool               train,       // true after training evaluation, false after validation evaluation
+            ggml_opt_context_t opt_ctx,
+            ggml_opt_dataset_t dataset,
+            ggml_opt_result_t  result,      // result associated with the dataset subsection
+            int64_t            ibatch,      // number of batches that have been evaluated so far
+            int64_t            ibatch_max,  // total number of batches in this dataset subsection
+            int64_t            t_start_us); // time at which the evaluation on the dataset subsection was started
+
+    // do training on front of dataset, do evaluation only on back of dataset
+    GGML_API void ggml_opt_epoch(
+            ggml_opt_context_t      opt_ctx,
+            ggml_opt_dataset_t      dataset,
+            ggml_opt_result_t       result_train,   // result to increment during training, ignored if NULL
+            ggml_opt_result_t       result_eval,    // result to increment during evaluation, ignored if NULL
+            int64_t                 idata_split,    // data index at which to split training and evaluation
+            ggml_opt_epoch_callback callback_train,
+            ggml_opt_epoch_callback callback_eval);
+
+    // callback that prints a progress bar on stderr
+    GGML_API void ggml_opt_epoch_callback_progress_bar(
+            bool               train,
+            ggml_opt_context_t opt_ctx,
+            ggml_opt_dataset_t dataset,
+            ggml_opt_result_t  result,
+            int64_t            ibatch,
+            int64_t            ibatch_max,
+            int64_t            t_start_us);
+
+    // fit model defined by inputs and outputs to dataset
+    GGML_API void ggml_opt_fit(
+            ggml_backend_sched_t            backend_sched,  // backend scheduler for constructing the compute graphs
+            ggml_context                  * ctx_compute,    // context with temporarily allocated tensors to calculate the outputs
+            ggml_tensor                   * inputs,         // input tensor with shape [ne_datapoint, ndata_batch]
+            ggml_tensor                   * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
+            ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
+            enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
+            int64_t                         nepoch,         // how many times the dataset should be iterated over
+            int64_t                         nbatch_logical, // datapoints per optimizer step, must be a multiple of ndata_batch in inputs/outputs
+            float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
+            bool                            silent);        // whether or not info prints to stderr should be suppressed
+
+#ifdef __cplusplus
+}
+#endif
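
To make the "Intended Usage" steps above concrete, here is a minimal, hedged end-to-end sketch of the new high-level API. The model (one linear layer), all sizes, the zero-filled data, and the CPU-only scheduler setup are illustrative assumptions and not part of this commit.

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"
    #include "ggml-opt.h"

    #include <string.h>

    int main(void) {
        const int64_t ne_in = 4, ne_out = 2, ndata = 1024, nbatch = 32;

        // 1./2. dataset: the data/labels tensors live in a CPU buffer and can be filled directly
        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_in, ne_out, ndata, /*ndata_shard =*/ 1);
        memset(ggml_opt_dataset_data  (dataset)->data, 0, ggml_nbytes(ggml_opt_dataset_data  (dataset))); // zero-filled for brevity
        memset(ggml_opt_dataset_labels(dataset)->data, 0, ggml_nbytes(ggml_opt_dataset_labels(dataset)));

        // 3. two contexts: static (parameters + inputs, allocated by user code) and compute (everything else)
        struct ggml_init_params ip_static  = { /*.mem_size =*/ 4*ggml_tensor_overhead(),                      /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true };
        struct ggml_init_params ip_compute = { /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true }; // generous size (assumption)
        struct ggml_context * ctx_static  = ggml_init(ip_static);
        struct ggml_context * ctx_compute = ggml_init(ip_compute);

        struct ggml_tensor * weights = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, ne_in, ne_out);
        ggml_set_param(ctx_static, weights); // trainable
        struct ggml_tensor * inputs  = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, ne_in, nbatch);

        ggml_backend_t        backend    = ggml_backend_cpu_init();
        ggml_backend_buffer_t buf_static = ggml_backend_alloc_ctx_tensors(ctx_static, backend);

        struct ggml_tensor * outputs = ggml_mul_mat(ctx_compute, weights, inputs); // shape [ne_out, nbatch]

        ggml_backend_buffer_type_t buft  = ggml_backend_get_default_buffer_type(backend);
        ggml_backend_sched_t       sched = ggml_backend_sched_new(&backend, &buft, 1, GGML_DEFAULT_GRAPH_SIZE, false);

        // 4. fit
        ggml_opt_fit(sched, ctx_compute, inputs, outputs, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
                     ggml_opt_get_default_optimizer_params, /*nepoch =*/ 2, /*nbatch_logical =*/ nbatch,
                     /*val_split =*/ 0.1f, /*silent =*/ false);

        ggml_backend_sched_free(sched);
        ggml_backend_buffer_free(buf_static);
        ggml_backend_free(backend);
        ggml_free(ctx_compute);
        ggml_free(ctx_static);
        ggml_opt_dataset_free(dataset);
        return 0;
    }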
ggml/include/ggml.h CHANGED
@@ -602,7 +602,6 @@ extern "C" {

         int32_t flags;

-        struct ggml_tensor * grad;
         struct ggml_tensor * src[GGML_MAX_SRC];

         // source tensor and offset for views
@@ -615,7 +614,7 @@ extern "C" {

         void * extra; // extra things e.g. for ggml-cuda.cu

-        // char padding[4];
+        char padding[8];
     };

     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -1985,28 +1984,20 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * grad,
-            float                 alpha,
-            float                 beta1,
-            float                 beta2,
-            float                 eps,
-            float                 wd); // weight decay
+            struct ggml_tensor  * m,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * adamw_params); // parameters such as the learning rate

     //
     // automatic differentiation
     //

-    GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);
-
-    GGML_API void ggml_build_opt_adamw(
-            struct ggml_context * ctx,
-            struct ggml_cgraph  * gf,
-            struct ggml_cgraph  * gb,
-            float                 alpha,
-            float                 beta1,
-            float                 beta2,
-            float                 eps,
-            float                 wd); // weight decay
+    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API void ggml_build_backward_expand(
+            struct ggml_context * ctx_static,  // context for static gradients (loss + gradient accumulation)
+            struct ggml_context * ctx_compute, // context for gradient computation
+            struct ggml_cgraph  * cgraph,
+            bool                  accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static

     // graph allocation in a context
     GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
@@ -2026,7 +2017,9 @@ extern "C" {
     GGML_API size_t ggml_graph_overhead(void);
     GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);

-    GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
+    GGML_API struct ggml_tensor * ggml_graph_get_tensor  (const struct ggml_cgraph * cgraph, const char * name);
+    GGML_API struct ggml_tensor * ggml_graph_get_grad    (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
+    GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);

     GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
     GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
@@ -2037,198 +2030,15 @@ extern "C" {
     // dump the graph into a file using the dot format
     GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);

-    // build gradient checkpointing backward graph gb for gf using provided checkpoints
-    // gb_tmp will contain original backward graph with rewritten backward process nodes,
-    // but without the second forward pass nodes.
-    GGML_API void ggml_build_backward_gradient_checkpointing(
-            struct ggml_context   * ctx,
-            struct ggml_cgraph    * gf,
-            struct ggml_cgraph    * gb,
-            struct ggml_cgraph    * gb_tmp,
-            struct ggml_tensor  * * checkpoints,
-            int                     n_checkpoints);
-    //
-    // optimization
-    //
-
-    // optimization methods
-    enum ggml_opt_type {
-        GGML_OPT_TYPE_ADAM,
-        GGML_OPT_TYPE_LBFGS,
-    };
-
-    // linesearch methods
-    enum ggml_linesearch {
-        GGML_LINESEARCH_DEFAULT = 1,
-
-        GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
-        GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
-        GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
-    };
-
-    // optimization return values
-    enum ggml_opt_result {
-        GGML_OPT_RESULT_OK = 0,
-        GGML_OPT_RESULT_DID_NOT_CONVERGE,
-        GGML_OPT_RESULT_NO_CONTEXT,
-        GGML_OPT_RESULT_INVALID_WOLFE,
-        GGML_OPT_RESULT_FAIL,
-        GGML_OPT_RESULT_CANCEL,
-
-        GGML_LINESEARCH_FAIL = -128,
-        GGML_LINESEARCH_MINIMUM_STEP,
-        GGML_LINESEARCH_MAXIMUM_STEP,
-        GGML_LINESEARCH_MAXIMUM_ITERATIONS,
-        GGML_LINESEARCH_INVALID_PARAMETERS,
-    };
-
-    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
+    // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);

     // Set callback for all future logging events.
     // If this is not called, or NULL is supplied, everything is output on stderr.
     GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);

-    // optimization parameters
-    //
-    //   see ggml.c (ggml_opt_default_params) for default values
-    //
-    struct ggml_opt_params {
-        enum ggml_opt_type type;
-
-        size_t graph_size;
-
-        int n_threads;
-
-        // delta-based convergence test
-        //
-        //   if past == 0 - disabled
-        //   if past > 0:
-        //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
-        //
-        int past;
-        float delta;
-
-        // maximum number of iterations without improvement
-        //
-        //   if 0 - disabled
-        //   if > 0:
-        //     assume convergence if no cost improvement in this number of iterations
-        //
-        int max_no_improvement;
-
-        bool print_forward_graph;
-        bool print_backward_graph;
-
-        int n_gradient_accumulation;
-
-        // ADAM parameters
-        struct {
-            int n_iter;
-
-            float sched; // schedule multiplier (fixed, decay or warmup)
-            float decay; // weight decay for AdamW, use 0.0f to disable
-            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float eps_f; // epsilon for convergence test
-            float eps_g; // epsilon for convergence test
-            float gclip; // gradient clipping
-        } adam;
-
-        // LBFGS parameters
-        struct {
-            int m; // number of corrections to approximate the inv. Hessian
-            int n_iter;
-            int max_linesearch;
-
-            float eps;      // convergence tolerance
-            float ftol;     // line search tolerance
-            float wolfe;
-            float min_step;
-            float max_step;
-
-            enum ggml_linesearch linesearch;
-        } lbfgs;
-    };
-
-    struct ggml_opt_context {
-        struct ggml_context * ctx;
-        struct ggml_opt_params params;
-
-        int iter;
-        int64_t nx; // number of parameter elements
-
-        bool just_initialized;
-
-        float loss_before;
-        float loss_after;
-
-        struct {
-            struct ggml_tensor * g;  // current gradient
-            struct ggml_tensor * m;  // first moment
-            struct ggml_tensor * v;  // second moment
-            struct ggml_tensor * pf; // past function values
-            float fx_best;
-            float fx_prev;
-            int n_no_improvement;
-        } adam;
-
-        struct {
-            struct ggml_tensor * x;    // current parameters
-            struct ggml_tensor * xp;   // previous parameters
-            struct ggml_tensor * g;    // current gradient
-            struct ggml_tensor * gp;   // previous gradient
-            struct ggml_tensor * d;    // search direction
-            struct ggml_tensor * pf;   // past function values
-            struct ggml_tensor * lmal; // the L-BFGS memory alpha
-            struct ggml_tensor * lmys; // the L-BFGS memory ys
-            struct ggml_tensor * lms;  // the L-BFGS memory s
-            struct ggml_tensor * lmy;  // the L-BFGS memory y
-            float fx_best;
-            float step;
-            int j;
-            int k;
-            int end;
-            int n_no_improvement;
-        } lbfgs;
-    };
-
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);

-    GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
-
-    // optimize the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt(
-            struct ggml_context * ctx,
-            struct ggml_opt_params params,
-            struct ggml_tensor * f);
-
-    // initialize optimizer context
-    GGML_API void ggml_opt_init(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_opt_params params,
-            int64_t nx);
-
-    // continue optimizing the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt_resume(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_tensor * f);
-
-    // continue optimizing the function defined by the tensor f
-    GGML_API enum ggml_opt_result ggml_opt_resume_g(
-            struct ggml_context * ctx,
-            struct ggml_opt_context * opt,
-            struct ggml_tensor * f,
-            struct ggml_cgraph * gf,
-            struct ggml_cgraph * gb,
-            ggml_opt_callback callback,
-            void * callback_data);
-
     //
     // quantization
     //
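
The key structural change in this header is that per-tensor gradients (the removed ggml_tensor.grad field and ggml_build_opt_adamw) are replaced by gradients stored in the cgraph and queried via ggml_graph_get_grad / ggml_graph_get_grad_acc. A minimal, hedged sketch of how the reworked calls fit together follows; the tiny "model" (sum of a parameter vector), the context sizes, and the use of ggml_set_loss are assumptions for illustration.

    #include "ggml.h"

    static void backward_graph_sketch(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024, // generous scratch for metadata + graph (assumption)
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx_static  = ggml_init(ip); // parameters, static gradients, gradient accumulators
        struct ggml_context * ctx_compute = ggml_init(ip); // forward + backward computation

        struct ggml_tensor * w = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 16);
        ggml_set_param(ctx_static, w); // mark as trainable

        struct ggml_tensor * loss = ggml_sum(ctx_compute, w); // stand-in for a real loss
        ggml_set_loss(loss);

        // gradients now live in the cgraph (grads = true), not in ggml_tensor:
        struct ggml_cgraph * gb = ggml_new_graph_custom(ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
        ggml_build_forward_expand(gb, loss);
        ggml_build_backward_expand(ctx_static, ctx_compute, gb, /*accumulate =*/ false);

        // query the gradient of w through the graph instead of w->grad:
        struct ggml_tensor * w_grad = ggml_graph_get_grad(gb, w);
        (void) w_grad;

        ggml_free(ctx_compute);
        ggml_free(ctx_static);
    }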
ggml/src/CMakeLists.txt CHANGED
@@ -207,9 +207,11 @@ add_library(ggml-base
             ../include/ggml-alloc.h
             ../include/ggml-backend.h
             ../include/ggml-cpp.h
+            ../include/ggml-opt.h
             ggml.c
             ggml-alloc.c
             ggml-backend.cpp
+            ggml-opt.cpp
             ggml-threading.cpp
             ggml-threading.h
             ggml-quants.c
ggml/src/ggml-alloc.c CHANGED
@@ -466,18 +466,12 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) {
     return ggml_gallocr_hash_get(galloc, t)->allocated;
 }

-static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) {
-    struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
-    hn->buffer_id = buffer_id;
-    hn->offset = offset;
-    hn->allocated = true;
-}
-
 static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) {
     return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
 }

 static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
+    GGML_ASSERT(buffer_id >= 0);
     struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);

     if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) {
@@ -816,7 +810,11 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
 }

 static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) {
-    size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    size_t node_size = 0;
+    if (!node->data && !node->view_src) {
+        GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API
+        node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node);
+    }
     return talloc->size_max >= node_size;
 }

ggml/src/ggml-backend.cpp CHANGED
@@ -279,7 +279,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
     buf->iface.get_tensor(buf, tensor, data, offset, size);
 }

-GGML_API void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

     if (size == 0) {
ggml/src/ggml-cpu/ggml-cpu.c CHANGED
@@ -12216,11 +12216,16 @@ static void ggml_compute_forward_opt_step_adamw_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {

-    const struct ggml_tensor * src0        = dst->src[0];
-    const struct ggml_tensor * src0_grad   = dst->src[1];
-    const struct ggml_tensor * src0_grad_m = dst->src[2];
-    const struct ggml_tensor * src0_grad_v = dst->src[3];
+    const struct ggml_tensor * src0         = dst->src[0];
+    const struct ggml_tensor * src0_grad    = dst->src[1];
+    const struct ggml_tensor * src0_grad_m  = dst->src[2];
+    const struct ggml_tensor * src0_grad_v  = dst->src[3];
+    const struct ggml_tensor * adamw_params = dst->src[4];
+
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

     const int ith = params->ith;
     const int nth = params->nth;
@@ -12237,16 +12242,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

-    /* const float gnorm = 1.0f; */
-    int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
-    const float alpha = ggml_get_op_params_f32(dst, 2);
-    const float beta1 = ggml_get_op_params_f32(dst, 3);
-    const float beta2 = ggml_get_op_params_f32(dst, 4);
-    const float eps   = ggml_get_op_params_f32(dst, 5);
-    const float wd    = ggml_get_op_params_f32(dst, 6);
-
-    const float beta1h = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha  = adamw_params_ptr[0];
+    const float beta1  = adamw_params_ptr[1];
+    const float beta2  = adamw_params_ptr[2];
+    const float eps    = adamw_params_ptr[3];
+    const float wd     = adamw_params_ptr[4];
+    const float beta1h = adamw_params_ptr[5];
+    const float beta2h = adamw_params_ptr[6];

     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
@@ -12270,17 +12273,9 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
         }
     }
-
-    ggml_barrier(params->threadpool);
-    if (ith != 0) {
-        return;
-    }
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
 }

 static void ggml_compute_forward_opt_step_adamw(
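
For reference, the per-element update the new kernel performs, written out with $m$ and $v$ the AdamW moments accumulated in the surrounding loop and $g$ the gradient, is

$$
\begin{aligned}
m &\leftarrow \beta_1 m + (1-\beta_1)\,g, \\
v &\leftarrow \beta_2 v + (1-\beta_2)\,g^2, \\
w &\leftarrow w\,(1 - \alpha\,\mathrm{wd}) \;-\; \alpha\,\frac{\texttt{beta1h}\cdot m}{\sqrt{\texttt{beta2h}\cdot v} + \epsilon}.
\end{aligned}
$$

Assuming the host fills $\texttt{beta1h} = 1/(1-\beta_1^t)$ and $\texttt{beta2h} = 1/(1-\beta_2^t)$, i.e. the same bias corrections that the removed code computed from the iteration counter stored in op_params, this is the decoupled-weight-decay AdamW update from the paper linked in the comment. The functional difference vs. the old kernel is only that the learning rate $\alpha$ is now applied explicitly instead of being folded into beta1h, and that the iteration counter no longer lives inside the op.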
ggml/src/ggml-cuda/opt-step-adamw.cu CHANGED
@@ -1,11 +1,11 @@
+#include "ggml-impl.h"
 #include "opt-step-adamw.cuh"

 #include <cstdint>

 static __global__ void opt_step_adamw_f32(
-    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k,
-    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
-    const float beta1h, const float beta2h) {
+    float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,
+    const float * __restrict__ pars, const int64_t k) {

     const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;

@@ -13,6 +13,14 @@ static __global__ void opt_step_adamw_f32(
         return;
     }

+    const float alpha  = pars[0];
+    const float beta1  = pars[1];
+    const float beta2  = pars[2];
+    const float eps    = pars[3];
+    const float wd     = pars[4];
+    const float beta1h = pars[5];
+    const float beta2h = pars[6];
+
     const float gi  = g[i];
     const float gmi = g_m[i]*beta1 +    gi*(1.0f - beta1);
     const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
@@ -23,58 +31,48 @@ static __global__ void opt_step_adamw_f32(
     const float mh =       gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;

-    x[i] = x[i]*(1.0f - alpha*wd) - mh/vh;
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
 }

 static void opt_step_adamw_f32_cuda(
-    float * x, const float * g, float * g_m, float * g_v, const int64_t k,
-    const float alpha, const float beta1, const float beta2, const float eps, const float wd,
-    const float beta1h, const float beta2h, cudaStream_t stream) {
+    float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {

     const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
     const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
-    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h);
+    opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);
 }

 void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0        = dst->src[0];
-    const ggml_tensor * src0_grad   = dst->src[1];
-    const ggml_tensor * src0_grad_m = dst->src[2];
-    const ggml_tensor * src0_grad_v = dst->src[3];
-
-    GGML_ASSERT(src0->type        == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad->type   == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * src0_grad_m  = dst->src[2];
+    const ggml_tensor * src0_grad_v  = dst->src[3];
+    const ggml_tensor * adamw_params = dst->src[4];
+
+    GGML_ASSERT(src0->type         == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type    == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_m->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad_v->type  == GGML_TYPE_F32);
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src0_grad));
     GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
     GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+    GGML_ASSERT(ggml_is_contiguous(adamw_params));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

-    float       * src0_d        = (float       *) src0->data;
-    const float * src0_grad_d   = (const float *) src0_grad->data;
-    float       * src0_grad_m_d = (float       *) src0_grad_m->data;
-    float       * src0_grad_v_d = (float       *) src0_grad_v->data;
+    float       * src0_d         = (float       *) src0->data;
+    const float * src0_grad_d    = (const float *) src0_grad->data;
+    float       * src0_grad_m_d  = (float       *) src0_grad_m->data;
+    float       * src0_grad_v_d  = (float       *) src0_grad_v->data;
+    const float * adamw_params_d = (const float *) adamw_params->data;

     cudaStream_t stream = ctx.stream();

     const int64_t ne = ggml_nelements(src0);

-    int64_t iter;  memcpy(&iter,  &dst->op_params[0], sizeof(int64_t));
-    float   alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float));
-    float   beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float));
-    float   beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float));
-    float   eps;   memcpy(&eps,   &dst->op_params[5], sizeof(float));
-    float   wd;    memcpy(&wd,    &dst->op_params[6], sizeof(float));
-
-    const float beta1h = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
-
-    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream);
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
+    opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);
 }
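
Both kernels now read their hyperparameters from the 7-element F32 tensor in dst->src[4] ("store adamw params in tensor" in the commit message). The part of ggml-opt.cpp that fills this tensor is not shown in this excerpt; the following is only a hedged sketch of what the host side has to provide, with the ordering taken from the kernels above and the bias corrections in slots 5/6 assumed to follow the removed per-iteration code.

    #include <math.h>

    #include "ggml.h"
    #include "ggml-backend.h"

    // Hedged sketch: fill the adamw_params tensor consumed by ggml_opt_step_adamw().
    static void set_adamw_params(struct ggml_tensor * adamw_params,
                                 float alpha, float beta1, float beta2, float eps, float wd, int64_t t) {
        float pars[7];
        pars[0] = alpha;                        // learning rate
        pars[1] = beta1;
        pars[2] = beta2;
        pars[3] = eps;                          // numerical stability
        pars[4] = wd;                           // weight decay
        pars[5] = 1.0f/(1.0f - powf(beta1, t)); // beta1h: first-moment bias correction (assumption)
        pars[6] = 1.0f/(1.0f - powf(beta2, t)); // beta2h: second-moment bias correction (assumption)
        ggml_backend_tensor_set(adamw_params, pars, 0, sizeof(pars));
    }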
ggml/src/ggml-impl.h CHANGED
@@ -196,7 +196,7 @@ void ggml_hash_set_reset(struct ggml_hash_set * hash_set);
 static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);

 // returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted
-static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key);
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key);

 // returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full
 static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key);
@@ -210,7 +210,7 @@ static inline size_t ggml_hash(const struct ggml_tensor * p) {
     return (size_t)(uintptr_t)p >> 4;
 }

-static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) {
+static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) {
     size_t h = ggml_hash(key) % hash_set->size;

     // linear probing
@@ -281,13 +281,14 @@ enum ggml_cgraph_eval_order {
 };

 struct ggml_cgraph {
-    int size;
-    int n_nodes;
-    int n_leafs;
-
-    struct ggml_tensor ** nodes;
-    struct ggml_tensor ** grads;
-    struct ggml_tensor ** leafs;
+    int size;    // maximum number of nodes/leafs/grads/grad_accs
+    int n_nodes; // number of nodes currently in use
+    int n_leafs; // number of leafs currently in use
+
+    struct ggml_tensor ** nodes;     // tensors with data that can change if the graph is evaluated
+    struct ggml_tensor ** grads;     // the outputs of these tensors are the gradients of the nodes
+    struct ggml_tensor ** grad_accs; // accumulators for node gradients
+    struct ggml_tensor ** leafs;     // tensors with constant data

     struct ggml_hash_set visited_hash_set;

ggml/src/ggml-metal/ggml-metal.m CHANGED
@@ -3639,6 +3639,12 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->all_data;
 }

+static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    UNUSED(buffer);
+}
+
 static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);

@@ -3671,7 +3677,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
     /* .free_buffer   = */ ggml_backend_metal_buffer_free_buffer,
     /* .get_base      = */ ggml_backend_metal_buffer_get_base,
     /* .init_tensor   = */ NULL,
-    /* .memset_tensor = */ NULL,
+    /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor,
     /* .set_tensor    = */ ggml_backend_metal_buffer_set_tensor,
     /* .get_tensor    = */ ggml_backend_metal_buffer_get_tensor,
     /* .cpy_tensor    = */ ggml_backend_metal_buffer_cpy_tensor,
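
Filling in the memset_tensor hook matters for training because gradient and accumulator tensors have to be cleared in place on whatever backend buffer holds them, presumably via ggml_backend_tensor_memset as made non-static above. A hedged one-function illustration of the call that backends now need to support uniformly:

    #include "ggml.h"
    #include "ggml-backend.h"

    // Clear an allocated tensor (e.g. a gradient accumulator) in place on its backend buffer.
    static void zero_tensor(struct ggml_tensor * t) {
        ggml_backend_tensor_memset(t, 0, 0, ggml_nbytes(t));
    }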
ggml/src/ggml-opt.cpp ADDED
@@ -0,0 +1,867 @@
+#include "ggml-opt.h"
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <inttypes.h>
+#include <map>
+#include <random>
+#include <vector>
+
+struct ggml_opt_dataset {
+    struct ggml_context * ctx;
+    ggml_backend_buffer_t buf;
+    struct ggml_tensor  * data;
+    struct ggml_tensor  * labels;
+
+    int64_t ndata;
+    int64_t ndata_shard;
+    size_t  nbs_data;
+    size_t  nbs_labels;
+
+    std::vector<int64_t> permutation;
+};
+
+struct ggml_opt_context {
+    ggml_backend_sched_t  backend_sched;
+    ggml_cgraph         * allocated_graph;
+    ggml_cgraph         * allocated_graph_copy;
+    struct ggml_context * ctx_static;
+    struct ggml_context * ctx_static_cpu;
+    struct ggml_context * ctx_compute;
+    struct ggml_context * ctx_copy;
+    ggml_backend_buffer_t buf_static;
+    ggml_backend_buffer_t buf_static_cpu;
+    std::mt19937 rng;
+
+    struct ggml_tensor * inputs;
+    struct ggml_tensor * outputs;
+    struct ggml_tensor * labels;
+
+    struct ggml_tensor * loss;
+    struct ggml_tensor * pred;
+    struct ggml_tensor * ncorrect;
+
+    struct ggml_cgraph * gf;
+    struct ggml_cgraph * gb_grad;
+    struct ggml_cgraph * gb_opt;
+
+    int64_t iter;
+    int32_t opt_period;
+    int32_t opt_i;
+    bool    loss_per_datapoint;
+
+    ggml_opt_get_optimizer_params get_opt_pars;
+    void * get_opt_pars_ud;
+    struct ggml_tensor * adamw_params;
+};
+
+struct ggml_opt_result {
+    int64_t              ndata = 0;
+    std::vector<float>   loss;
+    std::vector<int32_t> pred;
+    int64_t              ncorrect = 0;
+
+    bool    loss_per_datapoint = false;
+    int64_t opt_period         = -1;
+};
+
+// ====== Dataset ======
+
+ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
+    GGML_ASSERT(ne_datapoint >  0);
+    GGML_ASSERT(ne_label     >= 0);
+    GGML_ASSERT(ndata        >  0);
+    GGML_ASSERT(ndata_shard  >  0);
+
+    ggml_opt_dataset_t result = new ggml_opt_dataset;
+    result->ndata       = ndata;
+    result->ndata_shard = ndata_shard;
+
+    {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ 2*ggml_tensor_overhead(),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        result->ctx = ggml_init(params);
+    }
+
+    result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
+    result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;
+
+    if (ne_label > 0) {
+        result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
+        result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
+    } else {
+        result->labels = nullptr;
+        result->nbs_labels = 0;
+    }
+
+    result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type());
+
+    const int64_t nshards = ndata/ndata_shard;
+    result->permutation.resize(nshards);
+    for (int64_t i = 0; i < nshards; ++i) {
+        result->permutation[i] = i;
+    }
+    return result;
+}
+
+void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
+    ggml_backend_buffer_free(dataset->buf);
+    ggml_free(dataset->ctx);
+    delete dataset;
+}
+
+struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
+    return dataset->data;
+}
+
+struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) {
+    return dataset->labels;
+}
+
+void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) {
+    GGML_ASSERT(idata <= dataset->ndata);
+
+    if (idata < 0) {
+        std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng);
+        return;
+    }
+
+    GGML_ASSERT(idata % dataset->ndata_shard == 0);
+    const int64_t ishard_max = idata / dataset->ndata_shard;
+    std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng);
+}
+
+void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) {
+    GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch));
+    GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
+    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+
+    const size_t nb_data_batch = ggml_nbytes(data_batch);
+    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+    if (labels_batch) {
+        const size_t nb_labels_batch = ggml_nbytes(labels_batch);
+        GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels);
+    }
+
+    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+        const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data;
+        ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data);
+
+        if (!labels_batch) {
+            continue;
+        }
+
+        const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels;
+        ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels);
+    }
+}
+
+// ====== Model / Context ======
+
+struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
+    GGML_UNUSED(userdata);
+
+    ggml_opt_optimizer_params result;
+
+    result.adamw.alpha = 0.001f;
+    result.adamw.beta1 = 0.9f;
+    result.adamw.beta2 = 0.999f;
+    result.adamw.eps   = 1e-8f;
+    result.adamw.wd    = 0.0f;
+
+    return result;
+}
+
+struct ggml_opt_params ggml_opt_default_params(
+        ggml_backend_sched_t      backend_sched,
+        struct ggml_context     * ctx_compute,
+        struct ggml_tensor      * inputs,
+        struct ggml_tensor      * outputs,
+        enum ggml_opt_loss_type   loss_type) {
+    return {
+        /*backend_sched   =*/ backend_sched,
+        /*ctx_compute     =*/ ctx_compute,
+        /*inputs          =*/ inputs,
+        /*logits          =*/ outputs,
+        /*loss_type       =*/ loss_type,
+        /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
+        /*opt_period      =*/ 1,
+        /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
+        /*get_opt_pars_ud =*/ nullptr,
+    };
+}
+
+static ggml_tensor * map_tensor(std::map<ggml_tensor *, ggml_tensor *> & tensor_map, ggml_context * ctx, ggml_tensor * tensor) {
+    if (!tensor) {
+        return nullptr;
+    }
+
+    if (tensor_map.find(tensor) != tensor_map.end()) {
+        return tensor_map[tensor];
+    }
+
+    ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor);
+    tensor_map[tensor] = new_tensor;
+
+    new_tensor->op = tensor->op;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        new_tensor->nb[i] = tensor->nb[i];
+    }
+    new_tensor->flags = tensor->flags;
+    memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params));
+    strcpy(new_tensor->name, tensor->name);
+    new_tensor->data      = tensor->data;
+    new_tensor->buffer    = tensor->buffer;
+    new_tensor->extra     = tensor->extra;
+    new_tensor->view_offs = tensor->view_offs;
+    new_tensor->view_src  = map_tensor(tensor_map, ctx, tensor->view_src);
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]);
+    }
+
+    return new_tensor;
+}
+
+static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) {
+    std::map<ggml_tensor *, ggml_tensor *> tensor_map;
+
+    ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
+
+    for (int i = 0; i < graph->n_leafs; i++) {
+        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i]));
+    }
+    for (int i = 0; i < graph->n_nodes; i++) {
+        ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i]));
+    }
+    for (int i = 0; i < graph->n_nodes; ++i) {
+        const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set,     graph->nodes[i]);
+        const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]);
+        graph->grads[igrad_dst]     = new_graph->grads[igrad_src];
+        graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src];
+    }
+
+    return new_graph;
+}
+
+static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) {
+    GGML_ASSERT(graph);
+    if (opt_ctx->allocated_graph == graph) {
+        return;
+    }
+
+    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+
+    {
+        ggml_init_params params = {
+            /*.mem_size   =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE,
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_free(opt_ctx->ctx_copy);
+        opt_ctx->ctx_copy = ggml_init(params);
+    }
+
+    opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+
+    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->allocated_graph = graph;
+}
+
+ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
+    ggml_opt_context_t result = new struct ggml_opt_context;
+    result->backend_sched        = params.backend_sched;
+    result->allocated_graph      = nullptr;
+    result->allocated_graph_copy = nullptr;
+    result->ctx_compute          = params.ctx_compute;
+    result->ctx_copy             = nullptr;
+    result->inputs               = params.inputs;
+    result->outputs              = params.outputs;
+    result->iter                 = 1;
+    result->opt_period           = params.opt_period;
+    result->opt_i                = 0;
+    result->get_opt_pars         = params.get_opt_pars;
+    result->get_opt_pars_ud      = params.get_opt_pars_ud;
+
+    GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically");
+    GGML_ASSERT(result->opt_period >= 1);
+
+    const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD ||
+        (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1);
305
+
306
+ ggml_set_input(result->inputs);
307
+ ggml_set_output(result->outputs);
308
+
309
+ result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass.
310
+ ggml_build_forward_expand(result->gf, result->outputs);
311
+
312
+ int n_param = 0;
313
+ for (int i = 0; i < result->gf->n_nodes; ++i) {
314
+ if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) {
315
+ n_param++;
316
+ }
317
+ }
318
+
319
+ {
320
+ // The static context is used for:
321
+ // - gradients (1 tensor per param if using gradient accumulation)
322
+ // - optimizer momenta (2 tensors per param)
323
+ // - labels
324
+ // - loss + its gradient (up to 5 tensors)
325
+ // - pred
326
+ // - ncorrect (2 tensors).
327
+ const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
328
+ const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead();
329
+ struct ggml_init_params params = {
330
+ /*.mem_size =*/ size_meta,
331
+ /*.mem_buffer =*/ nullptr,
332
+ /*.no_alloc =*/ true,
333
+ };
334
+ result->ctx_static = ggml_init(params);
335
+ }
336
+ {
337
+ // The static cpu context is used for:
338
+ // - optimizer parameters (1 for the entire context)
339
+ const size_t size_meta = 1 * ggml_tensor_overhead();
340
+ struct ggml_init_params params = {
341
+ /*.mem_size =*/ size_meta,
342
+ /*.mem_buffer =*/ nullptr,
343
+ /*.no_alloc =*/ true,
344
+ };
345
+ result->ctx_static_cpu = ggml_init(params);
346
+ }
347
+
348
+
349
+ switch (params.loss_type) {
350
+ case GGML_OPT_LOSS_TYPE_MEAN: {
351
+ result->labels = nullptr;
352
+ result->loss = ggml_sum(result->ctx_static, result->outputs);
353
+ ggml_set_name(result->loss, "loss_sum");
354
+ const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
355
+ result->loss = ggml_scale(result->ctx_static, result->loss, scale);
356
+ ggml_set_name(result->loss, "loss_mean");
357
+ result->loss_per_datapoint = true;
358
+ break;
359
+ }
360
+ case GGML_OPT_LOSS_TYPE_SUM: {
361
+ result->labels = nullptr;
362
+ result->loss = ggml_sum(result->ctx_static, result->outputs);
363
+ ggml_set_name(result->loss, "loss_sum");
364
+ result->loss_per_datapoint = false;
365
+ break;
366
+ }
367
+ case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: {
368
+ result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
369
+ ggml_set_input(result->labels);
370
+ ggml_set_name(result->labels, "labels");
371
+ result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels);
372
+ ggml_set_name(result->loss, "loss_cross_entropy");
373
+ if (result->opt_period > 1) {
374
+ result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period);
375
+ ggml_set_name(result->loss, "loss_cross_entropy_scaled");
376
+ }
377
+ result->loss_per_datapoint = true;
378
+ break;
379
+ }
380
+ case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: {
381
+ result->labels = ggml_dup_tensor(result->ctx_static, result->outputs);
382
+ ggml_set_input(result->labels);
383
+ ggml_set_name(result->labels, "labels");
384
+ result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels);
385
+ ggml_set_name(result->loss, "loss_error");
386
+ result->loss = ggml_sqr(result->ctx_static, result->loss);
387
+ ggml_set_name(result->loss, "loss_squared_error");
388
+ result->loss = ggml_sum(result->ctx_static, result->loss);
389
+ ggml_set_name(result->loss, "loss_sum_squared_error");
390
+ const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs));
391
+ result->loss = ggml_scale(result->ctx_static, result->loss, scale);
392
+ ggml_set_name(result->loss, "loss_mean_squared_error");
393
+ result->loss_per_datapoint = true;
394
+ break;
395
+ }
396
+ }
397
+ ggml_set_output(result->loss);
398
+ ggml_set_loss(result->loss);
399
+ ggml_build_forward_expand(result->gf, result->loss);
400
+
401
+ result->pred = ggml_argmax(result->ctx_static, result->outputs);
402
+ ggml_set_name(result->pred, "pred");
403
+ ggml_set_output(result->pred);
404
+ ggml_build_forward_expand(result->gf, result->pred);
405
+
406
+ if (result->labels) {
407
+ result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels));
408
+ ggml_set_name(result->ncorrect, "ncorrect");
409
+ ggml_set_output(result->ncorrect);
410
+ ggml_build_forward_expand(result->gf, result->ncorrect);
411
+ } else {
412
+ result->ncorrect = nullptr;
413
+ }
414
+
415
+ if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) {
416
+ result->gb_grad = nullptr;
417
+ result->gb_opt = nullptr;
418
+
419
+ result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
420
+ result->buf_static_cpu = nullptr;
421
+
422
+ ggml_opt_alloc_graph(result, result->gf);
423
+
424
+ return result;
425
+ }
426
+
427
+ // gb_grad == graph backward gradients: forward pass, then backward pass to calculate the gradients.
428
+ result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
429
+ ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
430
+
431
+ if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
432
+ result->gb_opt = nullptr;
433
+
434
+ result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
435
+ result->buf_static_cpu = nullptr;
436
+
437
+ ggml_opt_alloc_graph(result, result->gb_grad);
438
+ ggml_graph_reset(result->gb_grad);
439
+
440
+ return result;
441
+ }
442
+
443
+ GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
444
+
445
+ // gb_opt == graph backward optimize: forward pass, then backward pass to calculate the gradients, then the optimizer step.
446
+ result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
447
+
448
+ result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
449
+ ggml_set_input(result->adamw_params);
450
+ ggml_set_name(result->adamw_params, "adamw_params");
451
+
452
+ for (int i = result->gf->n_nodes-1; i >= 0; --i) {
453
+ struct ggml_tensor * node = result->gb_opt->nodes[i];
454
+ struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
455
+
456
+ if (node->flags & GGML_TENSOR_FLAG_PARAM) {
457
+ struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
458
+ struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
459
+ struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
460
+ ggml_build_forward_expand(result->gb_opt, opt_step);
461
+ }
462
+ }
463
+
464
+ result->buf_static = ggml_backend_alloc_ctx_tensors(
465
+ result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
466
+
467
+ result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
468
+
469
+ ggml_opt_alloc_graph(result, result->gb_opt);
470
+ ggml_graph_reset(result->gb_opt);
471
+
472
+ return result;
473
+ }
474
+
475
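ggml_opt_init() expects the parameter tensors to be flagged with GGML_TENSOR_FLAG_PARAM and the inputs to be statically allocated before it is called. A rough end-to-end sketch of that setup (sizes, context names and the buffer allocation are illustrative, not part of this commit):

// Sketch: a single linear classifier wired up for ggml_opt_init().
struct ggml_tensor * inputs  = ggml_new_tensor_2d(ctx_static_user, GGML_TYPE_F32, 784, /*nbatch =*/ 64);
struct ggml_tensor * weights = ggml_new_tensor_2d(ctx_static_user, GGML_TYPE_F32, 784, 10);
ggml_set_param(ctx_static_user, weights); // sets GGML_TENSOR_FLAG_PARAM -> grads + momenta are created above
// ... allocate the ctx_static_user tensors in a backend buffer so that inputs->data != NULL ...

struct ggml_tensor * outputs = ggml_mul_mat(ctx_compute, weights, inputs); // logits, shape [10, 64]

ggml_opt_params params = ggml_opt_default_params(sched, ctx_compute, inputs, outputs, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
ggml_opt_context_t opt_ctx = ggml_opt_init(params);
// opt_ctx now owns gf, gb_grad, gb_opt and the loss/pred/ncorrect tensors constructed above.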
+ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
476
+ if (opt_ctx == nullptr) {
477
+ return;
478
+ }
479
+ ggml_backend_buffer_free(opt_ctx->buf_static);
480
+ ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
481
+ ggml_free(opt_ctx->ctx_static);
482
+ ggml_free(opt_ctx->ctx_static_cpu);
483
+ delete opt_ctx;
484
+ }
485
+
486
+ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
487
+ if (optimizer) {
488
+ ggml_graph_reset(opt_ctx->gb_opt);
489
+ opt_ctx->iter = 1;
490
+ } else {
491
+ ggml_graph_reset(opt_ctx->gb_grad);
492
+ }
493
+ }
494
+
495
+ struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
496
+ return opt_ctx->inputs;
497
+ }
498
+
499
+ struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) {
500
+ return opt_ctx->outputs;
501
+ }
502
+
503
+ struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) {
504
+ return opt_ctx->labels;
505
+ }
506
+
507
+ struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) {
508
+ return opt_ctx->loss;
509
+ }
510
+
511
+ struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) {
512
+ return opt_ctx->pred;
513
+ }
514
+
515
+ struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) {
516
+ return opt_ctx->ncorrect;
517
+ }
518
+
519
+ struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) {
520
+ return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node);
521
+ }
522
+
523
+ // ====== Optimization Result ======
524
+
525
+ ggml_opt_result_t ggml_opt_result_init() {
526
+ return new ggml_opt_result;
527
+ }
528
+
529
+ void ggml_opt_result_free(ggml_opt_result_t result) {
530
+ delete result;
531
+ }
532
+
533
+ void ggml_opt_result_reset(ggml_opt_result_t result) {
534
+ result->ndata = 0;
535
+ result->loss.clear();
536
+ result->pred.clear();
537
+ result->ncorrect = 0;
538
+ }
539
+
540
+ void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) {
541
+ *ndata = result->ndata;
542
+ }
543
+
544
+ void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) {
545
+ const int64_t nbatches = result->loss.size(); // Number of physical batches.
546
+
547
+ if (nbatches == 0) {
548
+ *loss = 0.0;
549
+ *unc = NAN;
550
+ return;
551
+ }
552
+
553
+ double sum = 0.0;
554
+ double sum_squared = 0.0;
555
+
556
+ for (const float & loss : result->loss) {
557
+ // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch.
558
+ const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss;
559
+ sum += loss_scaled;
560
+ sum_squared += loss_scaled*loss_scaled;
561
+ }
562
+
563
+ const double mean = sum/nbatches;
564
+ *loss = result->loss_per_datapoint ? mean : sum;
565
+
566
+ if (!unc) {
567
+ return;
568
+ }
569
+
570
+ if (nbatches < 2) {
571
+ *unc = NAN;
572
+ return;
573
+ }
574
+
575
+ const double var_sum = sum_squared/nbatches - mean*mean; // biased sample variance, i.e. without the Bessel correction factor nbatches/(nbatches-1)
576
+ *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1));
577
+ }
578
+
579
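A note on the uncertainty computed above: var_sum is the biased sample variance of the per-batch losses, so both branches reduce to familiar estimators. With $n$ physical batches, $\bar x = \frac{1}{n}\sum_i x_i$ and $\overline{x^2} = \frac{1}{n}\sum_i x_i^2$:

$$\mathrm{var\_sum} = \overline{x^2} - \bar x^2, \qquad \sqrt{\frac{\mathrm{var\_sum}}{n-1}} = \frac{s}{\sqrt{n}} \quad\text{with}\quad s^2 = \frac{n}{n-1}\,\mathrm{var\_sum}.$$

The per-datapoint branch therefore returns the standard error of the mean with Bessel's correction applied, while the sum branch returns $s$ itself, the corrected standard deviation of the per-batch losses.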
+ void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) {
580
+ for (size_t i = 0; i < result->pred.size(); ++i) {
581
+ pred[i] = result->pred[i];
582
+ }
583
+ }
584
+
585
+ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) {
586
+ *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN;
587
+
588
+ if (!unc) {
589
+ return;
590
+ }
591
+
592
+ *unc = result->ncorrect >= 0 && result->ndata >= 2 ?
593
+ sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN;
594
+ }
595
+
596
+ // ====== Computation ======
597
+
598
+ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
599
+ if (graph != opt_ctx->gf) {
600
+ struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
601
+
602
+ GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
603
+ GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
604
+ GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
605
+ GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
606
+ GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
607
+ GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
608
+ GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
609
+ GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
610
+
611
+ // reciprocal bias correction factors 1/(1 - beta^iter) for beta1, beta2 at the current iteration
612
+ const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
613
+ const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
614
+
615
+ float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
616
+ adamw_par_data[0] = opt_pars.adamw.alpha;
617
+ adamw_par_data[1] = opt_pars.adamw.beta1;
618
+ adamw_par_data[2] = opt_pars.adamw.beta2;
619
+ adamw_par_data[3] = opt_pars.adamw.eps;
620
+ adamw_par_data[4] = opt_pars.adamw.wd;
621
+ adamw_par_data[5] = beta1h;
622
+ adamw_par_data[6] = beta2h;
623
+ }
624
+
625
+ ggml_opt_alloc_graph(opt_ctx, graph);
626
+ ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
627
+ opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
628
+
629
+ if (!result) {
630
+ return;
631
+ }
632
+
633
+ if (result->ndata == 0) {
634
+ result->loss_per_datapoint = opt_ctx->loss_per_datapoint;
635
+ result->opt_period = opt_ctx->opt_period;
636
+ } else {
637
+ GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint);
638
+ GGML_ASSERT(result->opt_period == opt_ctx->opt_period);
639
+ }
640
+
641
+ const int64_t ndata = opt_ctx->outputs->ne[1];
642
+ GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported");
643
+ result->ndata += ndata;
644
+
645
+ GGML_ASSERT(ggml_is_scalar(opt_ctx->loss));
646
+ GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32);
647
+ float loss;
648
+ ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
649
+ result->loss.push_back(loss);
650
+
651
+ GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
652
+ std::vector<int32_t> pred(ndata);
653
+ ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
654
+ result->pred.insert(result->pred.end(), pred.begin(), pred.end());
655
+
656
+ if (!opt_ctx->labels || result->ncorrect < 0) {
657
+ result->ncorrect = -1;
658
+ return;
659
+ }
660
+
661
+ GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect));
662
+ GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64);
663
+ int64_t ncorrect;
664
+ ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect));
665
+ result->ncorrect += ncorrect;
666
+ }
667
+
668
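For reference, beta1h and beta2h packed above are the reciprocal bias-correction factors $1/(1-\beta_1^t)$ and $1/(1-\beta_2^t)$ for iteration $t$. The corresponding AdamW step that GGML_OP_OPT_STEP_ADAMW is expected to perform (a sketch, with $g_t$ the gradient, $\theta$ the parameter, and m/v the momenta allocated in ggml_opt_init) is roughly:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,$$
$$\hat m_t = m_t\,\texttt{beta1h}, \qquad \hat v_t = v_t\,\texttt{beta2h}, \qquad \theta_t = \theta_{t-1}\,(1-\alpha\,\mathrm{wd}) - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}.$$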
+ void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
669
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
670
+ }
671
+
672
+ void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
673
+ if (opt_ctx->opt_period == 1) {
674
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
675
+ return;
676
+ }
677
+
678
+ const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
679
+ if (opt_i_next == 0) {
680
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
681
+ ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
682
+ } else {
683
+ ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
684
+ }
685
+ opt_ctx->opt_i = opt_i_next;
686
+ }
687
+
688
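To make the opt_period bookkeeping concrete: with opt_period == N, the first N-1 calls per logical batch only evaluate gb_grad and accumulate gradients, and the N-th call evaluates gb_opt (one optimizer step) and then resets the accumulated gradients. A rough caller-side sketch, with a hypothetical helper for filling the inputs:

// Sketch: one logical batch made up of opt_period physical batches.
for (int32_t i = 0; i < opt_period; ++i) {
    fill_physical_batch(opt_ctx, i); // hypothetical helper that writes ggml_opt_inputs()/ggml_opt_labels()
    ggml_opt_forward_backward(opt_ctx, result);
}
// Exactly one optimizer step has been applied and the internal opt_i counter is back at 0.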
+ // ====== High-Level Functions ======
689
+
690
+ void ggml_opt_epoch(
691
+ ggml_opt_context_t opt_ctx,
692
+ ggml_opt_dataset_t dataset,
693
+ ggml_opt_result_t result_train,
694
+ ggml_opt_result_t result_eval,
695
+ int64_t idata_split,
696
+ ggml_opt_epoch_callback callback_train,
697
+ ggml_opt_epoch_callback callback_eval) {
698
+ struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
699
+ struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
700
+ struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
701
+ GGML_ASSERT(data->ne[0] == inputs->ne[0]);
702
+
703
+ const int64_t ndata = data->ne[1];
704
+ const int64_t ndata_batch = inputs->ne[1];
705
+
706
+ GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0);
707
+ const int64_t nbatches = ndata/ndata_batch;
708
+
709
+ idata_split = idata_split < 0 ? ndata : idata_split;
710
+ GGML_ASSERT(idata_split % ndata_batch == 0);
711
+ const int64_t ibatch_split = idata_split / ndata_batch;
712
+
713
+ int64_t ibatch = 0;
714
+ int64_t t_loop_start = ggml_time_us();
715
+ for (; ibatch < ibatch_split; ++ibatch) {
716
+ ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
717
+ ggml_opt_forward_backward(opt_ctx, result_train);
718
+ if (callback_train) {
719
+ callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start);
720
+ }
721
+ }
722
+ t_loop_start = ggml_time_us();
723
+ for (; ibatch < nbatches; ++ibatch) {
724
+ ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
725
+ ggml_opt_forward(opt_ctx, result_eval);
726
+ if (callback_eval) {
727
+ callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start);
728
+ }
729
+ }
730
+ }
731
+
732
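ggml_opt_epoch() requires both the dataset size and idata_split to be multiples of the physical batch size. A minimal call-site sketch with a roughly 90/10 train/validation split, reusing the progress-bar callback defined next for both phases:

// Sketch: run one epoch, training on the first ~90% of the (shuffled) data.
const int64_t ndata       = ggml_opt_dataset_data(dataset)->ne[1];
const int64_t nbatch      = ggml_opt_inputs(opt_ctx)->ne[1];
const int64_t idata_split = (ndata*9/10 / nbatch) * nbatch; // round down to a multiple of the batch size

ggml_opt_result_t result_train = ggml_opt_result_init();
ggml_opt_result_t result_val   = ggml_opt_result_init();
ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split,
               ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);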
+ void ggml_opt_epoch_callback_progress_bar(
733
+ bool train,
734
+ ggml_opt_context_t opt_ctx,
735
+ ggml_opt_dataset_t dataset,
736
+ ggml_opt_result_t result,
737
+ int64_t ibatch,
738
+ int64_t ibatch_max,
739
+ int64_t t_start_us) {
740
+ fprintf(stderr, "%s[", train ? "train: " : "val: ");
741
+
742
+ constexpr int64_t bar_length = 25;
743
+ for (int64_t j = 0; j < bar_length; ++j) {
744
+ const int64_t ibatch_j = ibatch_max * j/bar_length;
745
+ if (ibatch_j < ibatch) {
746
+ fprintf(stderr, "=");
747
+ } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
748
+ fprintf(stderr, ">");
749
+ } else {
750
+ fprintf(stderr, " ");
751
+ }
752
+ }
753
+
754
+ const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1];
755
+ const int64_t idata = ibatch*batch_size;
756
+ const int64_t idata_max = ibatch_max*batch_size;
757
+
758
+ double loss;
759
+ double loss_unc;
760
+ ggml_opt_result_loss(result, &loss, &loss_unc);
761
+
762
+ double accuracy;
763
+ double accuracy_unc;
764
+ ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc);
765
+
766
+ const int64_t t_ibatch_us = ggml_time_us() - t_start_us;
767
+ int64_t t_ibatch_s = t_ibatch_us / 1000000;
768
+ const int64_t t_ibatch_h = t_ibatch_s / 3600;
769
+ t_ibatch_s -= t_ibatch_h * 3600;
770
+ const int64_t t_ibatch_m = t_ibatch_s / 60;
771
+ t_ibatch_s -= t_ibatch_m * 60;
772
+
773
+ const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch;
774
+ int64_t t_eta_s = t_eta_us / 1000000;
775
+ const int64_t t_eta_h = t_eta_s / 3600;
776
+ t_eta_s -= t_eta_h * 3600;
777
+ const int64_t t_eta_m = t_eta_s / 60;
778
+ t_eta_s -= t_eta_m * 60;
779
+
780
+ fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
781
+ "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
782
+ idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
783
+ t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
784
+ if (ibatch == ibatch_max) {
785
+ fprintf(stderr, "\n");
786
+ }
787
+ fflush(stderr);
788
+
789
+ GGML_UNUSED(dataset);
790
+ }
791
+
792
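The progress bar above also serves as a reference for the ggml_opt_epoch_callback signature; a much smaller custom callback could look like the following sketch (not part of this commit):

// Sketch: a minimal epoch callback that only logs the running loss.
static void my_epoch_callback(
        bool               train,
        ggml_opt_context_t opt_ctx,
        ggml_opt_dataset_t dataset,
        ggml_opt_result_t  result,
        int64_t            ibatch,
        int64_t            ibatch_max,
        int64_t            t_start_us) {
    double loss;
    double loss_unc;
    ggml_opt_result_loss(result, &loss, &loss_unc);
    fprintf(stderr, "%s batch %" PRId64 "/%" PRId64 ": loss=%.6lf\n",
            train ? "train" : "val", ibatch, ibatch_max, loss);
    GGML_UNUSED(opt_ctx);
    GGML_UNUSED(dataset);
    GGML_UNUSED(t_start_us);
}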
+ void ggml_opt_fit(
793
+ ggml_backend_sched_t backend_sched,
794
+ ggml_context * ctx_compute,
795
+ ggml_tensor * inputs,
796
+ ggml_tensor * outputs,
797
+ ggml_opt_dataset_t dataset,
798
+ enum ggml_opt_loss_type loss_type,
799
+ ggml_opt_get_optimizer_params get_opt_pars,
800
+ int64_t nepoch,
801
+ int64_t nbatch_logical,
802
+ float val_split,
803
+ bool silent) {
804
+ ggml_time_init();
805
+ const int64_t t_start_us = ggml_time_us();
806
+
807
+ const int64_t ndata = ggml_opt_dataset_data(dataset)->ne[1];
808
+ const int64_t nbatch_physical = inputs->ne[1];
809
+ GGML_ASSERT(ndata % nbatch_logical == 0);
810
+ GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
811
+
812
+ const int64_t opt_period = nbatch_logical / nbatch_physical;
813
+ const int64_t nbatches_logical = ndata / nbatch_logical;
814
+
815
+ GGML_ASSERT(val_split >= 0.0f);
816
+ GGML_ASSERT(val_split < 1.0f);
817
+ const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical)
818
+ const int64_t idata_split = ibatch_split * nbatch_physical;
819
+
820
+ int64_t epoch = 1;
821
+
822
+ ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
823
+ params.opt_period = opt_period;
824
+ params.get_opt_pars = get_opt_pars;
825
+ params.get_opt_pars_ud = &epoch;
826
+ ggml_opt_context_t opt_ctx = ggml_opt_init(params);
827
+
828
+ // Shuffling the data is generally useful, but it only has an effect if not all data is used in a single batch.
829
+ if (nbatch_logical < ndata) {
830
+ ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation).
831
+ }
832
+
833
+ ggml_opt_result_t result_train = ggml_opt_result_init();
834
+ ggml_opt_result_t result_val = ggml_opt_result_init();
835
+
836
+ ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar;
837
+
838
+ for (; epoch <= nepoch; ++epoch) {
839
+ if (nbatch_logical < idata_split) {
840
+ ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split);
841
+ }
842
+
843
+ ggml_opt_result_reset(result_train);
844
+ ggml_opt_result_reset(result_val);
845
+
846
+ if (!silent) {
847
+ fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch);
848
+ }
849
+ ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback);
850
+ if (!silent) {
851
+ fprintf(stderr, "\n");
852
+ }
853
+ }
854
+
855
+ if (!silent) {
856
+ int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000;
857
+ const int64_t t_total_h = t_total_s / 3600;
858
+ t_total_s -= t_total_h * 3600;
859
+ const int64_t t_total_m = t_total_s / 60;
860
+ t_total_s -= t_total_m * 60;
861
+ fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s);
862
+ }
863
+
864
+ ggml_opt_free(opt_ctx);
865
+ ggml_opt_result_free(result_train);
866
+ ggml_opt_result_free(result_val);
867
+ }
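For the common case, ggml_opt_fit() above is the intended one-call entry point. A rough end-to-end sketch (model construction and dataset filling omitted; backend_sched, ctx_compute, inputs, outputs and dataset are assumed to exist):

// Sketch: train for 30 epochs with logical batches of 512 datapoints,
// split into physical batches of inputs->ne[1] datapoints each.
ggml_opt_fit(backend_sched, ctx_compute, inputs, outputs, dataset,
             GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
             ggml_opt_get_default_optimizer_params,
             /*nepoch         =*/ 30,
             /*nbatch_logical =*/ 512,
             /*val_split      =*/ 0.05f,
             /*silent         =*/ false);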
ggml/src/ggml.c CHANGED
@@ -1592,14 +1592,13 @@ static struct ggml_tensor * ggml_new_tensor_impl(
1592
  /*.op =*/ GGML_OP_NONE,
1593
  /*.op_params =*/ { 0 },
1594
  /*.flags =*/ 0,
1595
- /*.grad =*/ NULL,
1596
  /*.src =*/ { NULL },
1597
  /*.view_src =*/ view_src,
1598
  /*.view_offs =*/ view_offs,
1599
  /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
1600
  /*.name =*/ { 0 },
1601
  /*.extra =*/ NULL,
1602
- ///*.padding =*/ { 0 },
1603
  };
1604
 
1605
  #ifdef __clang__
@@ -4194,8 +4193,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
4194
  GGML_ASSERT(mask);
4195
  }
4196
 
4197
- bool is_node = false;
4198
-
4199
  // permute(0, 2, 1, 3)
4200
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
4201
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
@@ -4203,8 +4200,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
4203
  float params[] = { scale, max_bias, logit_softcap };
4204
  ggml_set_op_params(result, params, sizeof(params));
4205
 
4206
- result->op = GGML_OP_FLASH_ATTN_EXT;
4207
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4208
  result->src[0] = q;
4209
  result->src[1] = k;
4210
  result->src[2] = v;
@@ -4272,14 +4268,6 @@ struct ggml_tensor * ggml_flash_attn_back(
4272
 
4273
  GGML_ASSERT(ne2 % kvne2 == 0);
4274
 
4275
- bool is_node = false;
4276
-
4277
- if (q->grad || k->grad || v->grad) {
4278
- // when using this operation (in backwards pass) these grads are set.
4279
- // we don't want to create (big) grad of our result, so is_node is false.
4280
- is_node = false;
4281
- }
4282
-
4283
  // store gradients of q, k and v as continuous tensors concatenated in result.
4284
  // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
4285
  const int64_t elem_q = ggml_nelements(q);
@@ -4302,8 +4290,7 @@ struct ggml_tensor * ggml_flash_attn_back(
4302
  int32_t masked_i = masked ? 1 : 0;
4303
  ggml_set_op_params(result, &masked_i, sizeof(masked_i));
4304
 
4305
- result->op = GGML_OP_FLASH_ATTN_BACK;
4306
- result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4307
  result->src[0] = q;
4308
  result->src[1] = k;
4309
  result->src[2] = v;
@@ -4945,34 +4932,24 @@ struct ggml_tensor * ggml_opt_step_adamw(
4945
  struct ggml_context * ctx,
4946
  struct ggml_tensor * a,
4947
  struct ggml_tensor * grad,
4948
- float alpha,
4949
- float beta1,
4950
- float beta2,
4951
- float eps,
4952
- float wd) {
4953
  GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
4954
  GGML_ASSERT(ggml_are_same_shape(a, grad));
4955
- GGML_ASSERT(alpha > 0.0f);
4956
- GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
4957
- GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
4958
- GGML_ASSERT(eps >= 0.0f);
4959
- GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
4960
 
4961
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
4962
 
4963
- const int64_t iter = 1;
4964
- memcpy(&result->op_params[0], &iter, sizeof(int64_t));
4965
- ggml_set_op_params_f32(result, 2, alpha);
4966
- ggml_set_op_params_f32(result, 3, beta1);
4967
- ggml_set_op_params_f32(result, 4, beta2);
4968
- ggml_set_op_params_f32(result, 5, eps);
4969
- ggml_set_op_params_f32(result, 6, wd);
4970
-
4971
  result->op = GGML_OP_OPT_STEP_ADAMW;
4972
  result->src[0] = a;
4973
  result->src[1] = grad;
4974
- result->src[2] = ggml_dup_tensor(ctx, grad);
4975
- result->src[3] = ggml_dup_tensor(ctx, grad);
 
4976
 
4977
  return result;
4978
  }
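This hunk drops the scalar hyperparameters (and the internally duplicated moment tensors) from ggml_opt_step_adamw. As the ggml-opt.cpp code earlier in this diff shows, the optimizer context now allocates the momenta itself and feeds the AdamW parameters at run time through a small F32 tensor, so the call takes the form:

// Call shape as used by ggml_opt_init() above; m, v and adamw_params are allocated by ggml-opt.
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx_compute, node, grad, m, v, adamw_params);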
@@ -5041,1112 +5018,514 @@ static void ggml_hash_map_free(struct hash_map * map) {
5041
  GGML_FREE(map);
5042
  }
5043
 
5044
- // gradient checkpointing
5045
-
5046
- static struct ggml_tensor * ggml_recompute_graph_node(
5047
- struct ggml_context * ctx,
5048
- struct ggml_cgraph * graph,
5049
- struct hash_map * replacements,
5050
- struct ggml_tensor * node) {
5051
-
5052
- if (node == NULL) {
5053
- return NULL;
5054
- }
5055
-
5056
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
5057
- return node;
5058
- }
5059
-
5060
- if (!ggml_hash_contains(&graph->visited_hash_set, node)) {
5061
- return node;
5062
- }
5063
-
5064
- int count_children = 0;
5065
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
5066
- if (node->src[k]) {
5067
- ++count_children;
5068
- }
5069
- }
5070
-
5071
- if (count_children == 0) {
5072
- return node;
5073
- }
5074
-
5075
- size_t i = ggml_hash_find(&replacements->set, node);
5076
- GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full
5077
- if (replacements->set.keys[i] == node) {
5078
- return replacements->vals[i];
5079
- }
5080
-
5081
- struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne);
5082
-
5083
- // insert clone into replacements
5084
- GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite
5085
- replacements->set.keys[i] = node;
5086
- replacements->vals[i] = clone;
5087
-
5088
- clone->op = node->op;
5089
- clone->grad = node->grad;
5090
- clone->flags = node->flags;
5091
- clone->extra = node->extra;
5092
- for (int k = 0; k < GGML_MAX_DIMS; ++k) {
5093
- clone->nb[k] = node->nb[k];
5094
- }
5095
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
5096
- clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]);
5097
- }
5098
- if (node->view_src != NULL) {
5099
- clone->data = (node->view_src->data == NULL)
5100
- ? NULL // view_src not yet allocated
5101
- : (char *) node->view_src->data // view_src already allocated
5102
- + node->view_offs;
5103
- clone->view_src = node->view_src;
5104
- clone->view_offs = node->view_offs;
5105
- }
5106
-
5107
- GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t)));
5108
- GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME);
5109
- memcpy(clone->op_params, node->op_params, sizeof(node->op_params));
5110
- ggml_format_name(clone, "%s (clone)", ggml_get_name(node));
5111
-
5112
- return clone;
5113
- }
5114
-
5115
- void ggml_build_backward_gradient_checkpointing(
5116
- struct ggml_context * ctx,
5117
- struct ggml_cgraph * gf,
5118
- struct ggml_cgraph * gb,
5119
- struct ggml_cgraph * gb_tmp,
5120
- struct ggml_tensor * * checkpoints,
5121
- int n_checkpoints) {
5122
- ggml_graph_cpy(gf, gb_tmp);
5123
- ggml_build_backward_expand(ctx, gf, gb_tmp, false);
5124
-
5125
- if (n_checkpoints <= 0) {
5126
- ggml_graph_cpy(gb_tmp, gb);
5127
- return;
5128
- }
5129
-
5130
- struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints);
5131
-
5132
- // insert checkpoints in replacements
5133
- for (int i = 0; i < n_checkpoints; ++i) {
5134
- size_t k = ggml_hash_find(&replacements->set, checkpoints[i]);
5135
- GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full
5136
- GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite
5137
- replacements->set.keys[k] = checkpoints[i];
5138
- replacements->vals[k] = checkpoints[i];
5139
- }
5140
-
5141
- ggml_graph_cpy(gf, gb);
5142
- // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes],
5143
- // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]),
5144
- // by recomputing them from checkpoints
5145
- for (int i = gf->n_nodes; i<gb_tmp->n_nodes; ++i) {
5146
- struct ggml_tensor * node = gb_tmp->nodes[i];
5147
- for (int k = 0; k < GGML_MAX_SRC; ++k) {
5148
- // insert new tensors recomputing src, reusing already made replacements,
5149
- // remember replacements: remember new tensors with mapping from corresponding gf nodes
5150
- // recurse for input tensors,
5151
- // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
5152
- node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
5153
- }
5154
- // insert rewritten backward node with replacements made into resulting backward graph gb
5155
- ggml_build_forward_expand(gb, node);
5156
- }
5157
-
5158
- ggml_hash_map_free(replacements);
5159
- }
5160
-
5161
  // utility functions to change gradients
5162
  // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
5163
  // else if a is in zero_table, replace a
5164
  // else, just add/subtract/etc. the gradients
5165
 
5166
- static struct ggml_tensor * ggml_add_or_set(
5167
- struct ggml_context * ctx,
5168
- struct ggml_tensor * a,
5169
- struct ggml_tensor * b,
5170
- struct ggml_hash_set * zero_table,
5171
- struct ggml_hash_set * acc_table) {
5172
- if (ggml_hash_contains(acc_table, a)) {
5173
- struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
5174
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
5175
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
5176
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
5177
- return ret;
5178
- }
5179
- if (ggml_hash_contains(zero_table, a)) {
5180
- return b;
5181
  }
5182
- return ggml_add_impl(ctx, a, b, false);
5183
  }
5184
 
5185
- static struct ggml_tensor * ggml_acc_or_set(
5186
- struct ggml_context * ctx,
5187
- struct ggml_tensor * a,
5188
- struct ggml_tensor * b,
5189
- const size_t nb1,
5190
- const size_t nb2,
5191
- const size_t nb3,
5192
- const size_t offset,
5193
- struct ggml_hash_set * zero_table,
5194
- struct ggml_hash_set * acc_table) {
5195
- if (ggml_hash_contains(acc_table, a)) {
5196
- struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
5197
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
5198
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
5199
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
5200
- return ret;
5201
- }
5202
- if (ggml_hash_contains(zero_table, a)) {
5203
- struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
5204
- return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
5205
  }
5206
- return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
5207
  }
5208
 
5209
- static struct ggml_tensor * ggml_add1_or_set(
5210
- struct ggml_context * ctx,
5211
- struct ggml_tensor * a,
5212
- struct ggml_tensor * b,
5213
- struct ggml_hash_set * zero_table,
5214
- struct ggml_hash_set * acc_table) {
5215
- if (ggml_hash_contains(acc_table, a)) {
5216
- struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
5217
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
5218
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
5219
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
5220
- return ret;
5221
- }
5222
- if (ggml_hash_contains(zero_table, a)) {
5223
- return ggml_repeat(ctx, b, a);
5224
  }
5225
- return ggml_add1_impl(ctx, a, b, false);
5226
  }
5227
 
5228
- static struct ggml_tensor * ggml_sub_or_set(
5229
- struct ggml_context * ctx,
5230
- struct ggml_tensor * a,
5231
- struct ggml_tensor * b,
5232
- struct ggml_hash_set * zero_table,
5233
- struct ggml_hash_set * acc_table) {
5234
- if (ggml_hash_contains(acc_table, a)) {
5235
- struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
5236
- const size_t insert_result = ggml_hash_insert(acc_table, ret);
5237
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
5238
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
5239
- return ret;
5240
- }
5241
- if (ggml_hash_contains(zero_table, a)) {
5242
- return ggml_neg(ctx, b);
5243
  }
5244
- return ggml_sub_impl(ctx, a, b, false);
5245
  }
5246
 
5247
- static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
5248
  struct ggml_tensor * src0 = tensor->src[0];
5249
  struct ggml_tensor * src1 = tensor->src[1];
5250
  struct ggml_tensor * src2 = tensor->src[2];
5251
 
5252
  switch (tensor->op) {
5253
- case GGML_OP_DUP:
5254
- {
5255
- if (src0->grad) {
5256
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5257
- }
5258
- } break;
5259
- case GGML_OP_ADD:
5260
- {
5261
- if (src0->grad) {
5262
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5263
- }
5264
- if (src1->grad) {
5265
- if (ggml_are_same_shape(src0, src1)) {
5266
- src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
5267
- } else {
5268
- src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
5269
- }
5270
- }
5271
- } break;
5272
- case GGML_OP_ADD1:
5273
- {
5274
- if (src0->grad) {
5275
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5276
- }
5277
- if (src1->grad) {
5278
- src1->grad = ggml_add_or_set(ctx,
5279
- src1->grad,
5280
- ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
5281
- zero_table, acc_table);
5282
- }
5283
- } break;
5284
- case GGML_OP_ACC:
5285
- {
5286
- if (src0->grad) {
5287
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5288
- }
5289
- if (src1->grad) {
5290
- const size_t nb1 = ((int32_t *) tensor->op_params)[0];
5291
- const size_t nb2 = ((int32_t *) tensor->op_params)[1];
5292
- const size_t nb3 = ((int32_t *) tensor->op_params)[2];
5293
- const size_t offset = ((int32_t *) tensor->op_params)[3];
5294
-
5295
- struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
5296
- tensor->grad,
5297
- src1->grad->ne[0],
5298
- src1->grad->ne[1],
5299
- src1->grad->ne[2],
5300
- src1->grad->ne[3],
5301
- nb1, nb2, nb3, offset);
5302
-
5303
- src1->grad =
5304
- ggml_add_or_set(ctx,
5305
- src1->grad,
5306
- ggml_reshape(ctx,
5307
- ggml_cont(ctx, tensor_grad_view),
5308
- src1->grad),
5309
- zero_table, acc_table);
5310
- }
5311
- } break;
5312
- case GGML_OP_SUB:
5313
- {
5314
- if (src0->grad) {
5315
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5316
- }
5317
- if (src1->grad) {
5318
- src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
5319
- }
5320
- } break;
5321
- case GGML_OP_MUL:
5322
- {
5323
- if (src0->grad) {
5324
- src0->grad =
5325
- ggml_add_or_set(ctx,
5326
- src0->grad,
5327
- ggml_mul(ctx, src1, tensor->grad),
5328
- zero_table, acc_table);
5329
- }
5330
- if (src1->grad) {
5331
- src1->grad =
5332
- ggml_add_or_set(ctx,
5333
- src1->grad,
5334
- ggml_mul(ctx, src0, tensor->grad),
5335
- zero_table, acc_table);
5336
- }
5337
- } break;
5338
- case GGML_OP_DIV:
5339
- {
5340
- if (src0->grad) {
5341
- src0->grad =
5342
- ggml_add_or_set(ctx,
5343
- src0->grad,
5344
- ggml_div(ctx, tensor->grad, src1),
5345
- zero_table, acc_table);
5346
- }
5347
- if (src1->grad) {
5348
- src1->grad =
5349
- ggml_sub_or_set(ctx,
5350
- src1->grad,
5351
- ggml_mul(ctx,
5352
- tensor->grad,
5353
- ggml_div(ctx, tensor, src1)),
5354
- zero_table, acc_table);
5355
- }
5356
- } break;
5357
- case GGML_OP_SQR:
5358
- {
5359
- if (src0->grad) {
5360
- src0->grad =
5361
- ggml_add_or_set(ctx,
5362
- src0->grad,
5363
- ggml_scale(ctx,
5364
- ggml_mul(ctx, src0, tensor->grad),
5365
- 2.0f),
5366
- zero_table, acc_table);
5367
- }
5368
- } break;
5369
- case GGML_OP_SQRT:
5370
- {
5371
- if (src0->grad) {
5372
- src0->grad =
5373
- ggml_add_or_set(ctx,
5374
- src0->grad,
5375
- ggml_scale(ctx,
5376
- ggml_div(ctx,
5377
- tensor->grad,
5378
- tensor),
5379
- 0.5f),
5380
- zero_table, acc_table);
5381
- }
5382
- } break;
5383
- case GGML_OP_LOG:
5384
- {
5385
- if (src0->grad) {
5386
- src0->grad =
5387
- ggml_add_or_set(ctx,
5388
- src0->grad,
5389
- ggml_div(ctx,
5390
- tensor->grad,
5391
- src0),
5392
- zero_table, acc_table);
5393
- }
5394
- } break;
5395
- case GGML_OP_SIN:
5396
- {
5397
- if (src0->grad) {
5398
- src0->grad =
5399
- ggml_add_or_set(ctx,
5400
- src0->grad,
5401
- ggml_mul(ctx,
5402
- tensor->grad,
5403
- ggml_cos(ctx, src0)),
5404
- zero_table, acc_table);
5405
- }
5406
- } break;
5407
- case GGML_OP_COS:
5408
- {
5409
- if (src0->grad) {
5410
- src0->grad =
5411
- ggml_sub_or_set(ctx,
5412
- src0->grad,
5413
- ggml_mul(ctx,
5414
- tensor->grad,
5415
- ggml_sin(ctx, src0)),
5416
- zero_table, acc_table);
5417
- }
5418
- } break;
5419
- case GGML_OP_SUM:
5420
- {
5421
- if (src0->grad) {
5422
- src0->grad =
5423
- ggml_add1_or_set(ctx,
5424
- src0->grad,
5425
- tensor->grad,
5426
- zero_table, acc_table);
5427
- }
5428
- } break;
5429
- case GGML_OP_SUM_ROWS:
5430
- {
5431
- if (src0->grad) {
5432
- src0->grad =
5433
- ggml_add_or_set(ctx,
5434
- src0->grad,
5435
- ggml_repeat(ctx,
5436
- tensor->grad,
5437
- src0->grad),
5438
- zero_table, acc_table);
5439
- }
5440
- } break;
5441
- case GGML_OP_MEAN:
5442
- case GGML_OP_ARGMAX:
5443
- case GGML_OP_COUNT_EQUAL:
5444
- {
5445
- GGML_ABORT("fatal error"); // TODO: implement
5446
- }
5447
- case GGML_OP_REPEAT:
5448
- {
5449
- // necessary for llama
5450
- if (src0->grad) {
5451
- src0->grad = ggml_add_or_set(ctx,
5452
- src0->grad,
5453
- ggml_repeat_back(ctx, tensor->grad, src0->grad),
5454
- zero_table, acc_table);
5455
- }
5456
- } break;
5457
- case GGML_OP_REPEAT_BACK:
5458
- {
5459
- if (src0->grad) {
5460
- // TODO: test this
5461
- src0->grad = ggml_add_or_set(ctx,
5462
- src0->grad,
5463
- ggml_repeat(ctx, tensor->grad, src0->grad),
5464
- zero_table, acc_table);
5465
- }
5466
- } break;
5467
- case GGML_OP_CONCAT:
5468
- {
5469
- GGML_ABORT("fatal error"); // TODO: implement
5470
- }
5471
- case GGML_OP_SILU_BACK:
5472
- {
5473
- GGML_ABORT("fatal error"); // TODO: not implemented
5474
  }
5475
- case GGML_OP_NORM:
5476
- {
5477
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5478
  }
5479
- case GGML_OP_RMS_NORM:
5480
- {
5481
- // necessary for llama
5482
- if (src0->grad) {
5483
- float eps;
5484
- memcpy(&eps, tensor->op_params, sizeof(float));
5485
-
5486
- src0->grad = ggml_add_or_set(ctx,
5487
- src0->grad,
5488
- ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
5489
- zero_table, acc_table);
5490
  }
5491
- } break;
5492
- case GGML_OP_RMS_NORM_BACK:
5493
- {
5494
- GGML_ABORT("fatal error"); // TODO: not implemented
5495
  }
5496
- case GGML_OP_GROUP_NORM:
5497
- {
5498
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5499
  }
5500
- case GGML_OP_MUL_MAT:
5501
- {
5502
- // https://cs231n.github.io/optimization-2/#staged
5503
- // # forward pass
5504
- // s0 = np.random.randn(5, 10)
5505
- // s1 = np.random.randn(10, 3)
5506
- // t = s0.dot(s1)
5507
-
5508
- // # now suppose we had the gradient on t from above in the circuit
5509
- // dt = np.random.randn(*t.shape) # same shape as t
5510
- // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
5511
- // ds1 = t.T.dot(dt)
5512
-
5513
- // tensor.shape [m,p,qq,rr]
5514
- // src0.shape [n,m,q1,r1]
5515
- // src1.shape [n,p,qq,rr]
5516
-
5517
- // necessary for llama
5518
- if (src0->grad) {
5519
- struct ggml_tensor * s1_tg =
5520
- ggml_out_prod(ctx, // [n,m,qq,rr]
5521
- src1, // [n,p,qq,rr]
5522
- tensor->grad); // [m,p,qq,rr]
5523
- const int64_t qq = s1_tg->ne[2];
5524
- const int64_t rr = s1_tg->ne[3];
5525
- const int64_t q1 = src0->ne[2];
5526
- const int64_t r1 = src0->ne[3];
5527
- const bool ne2_broadcasted = qq > q1;
5528
- const bool ne3_broadcasted = rr > r1;
5529
- if (ne2_broadcasted || ne3_broadcasted) {
5530
- // sum broadcast repetitions of s1_tg into shape of src0
5531
- s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
5532
- }
5533
- src0->grad =
5534
- ggml_add_or_set(ctx,
5535
- src0->grad, // [n,m,q1,r1]
5536
- s1_tg, // [n,m,q1,r1]
5537
- zero_table, acc_table);
5538
- }
5539
- if (src1->grad) {
5540
- src1->grad =
5541
- ggml_add_or_set(ctx,
5542
- src1->grad, // [n,p,qq,rr]
5543
- // ggml_mul_mat(ctx, // [n,p,qq,rr]
5544
- // ggml_cont(ctx, // [m,n,q1,r1]
5545
- // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
5546
- // tensor->grad), // [m,p,qq,rr]
5547
-
5548
- // // when src0 is bigger than tensor->grad (this is mostly the case in llama),
5549
- // // avoid transpose of src0, rather transpose smaller tensor->grad
5550
- // // and then use ggml_out_prod
5551
- ggml_out_prod(ctx, // [n,p,qq,rr]
5552
- src0, // [n,m,q1,r1]
5553
- ggml_transpose(ctx, // [p,m,qq,rr]
5554
- tensor->grad)), // [m,p,qq,rr]
5555
- zero_table, acc_table);
5556
- }
5557
- } break;
5558
- case GGML_OP_MUL_MAT_ID:
5559
- {
5560
- GGML_ABORT("fatal error"); // TODO: not implemented
5561
  }
5562
- case GGML_OP_OUT_PROD:
5563
- {
5564
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5565
  }
5566
- case GGML_OP_SCALE:
5567
- {
5568
- // necessary for llama
5569
- if (src0->grad) {
5570
- float s;
5571
- memcpy(&s, tensor->op_params, sizeof(float));
5572
-
5573
- src0->grad =
5574
- ggml_add_or_set(ctx,
5575
- src0->grad,
5576
- ggml_scale_impl(ctx, tensor->grad, s, false),
5577
- zero_table, acc_table);
5578
- }
5579
- } break;
5580
- case GGML_OP_SET:
5581
- {
5582
- const size_t nb1 = ((int32_t *) tensor->op_params)[0];
5583
- const size_t nb2 = ((int32_t *) tensor->op_params)[1];
5584
- const size_t nb3 = ((int32_t *) tensor->op_params)[2];
5585
- const size_t offset = ((int32_t *) tensor->op_params)[3];
5586
-
5587
- struct ggml_tensor * tensor_grad_view = NULL;
5588
-
5589
- if (src0->grad || src1->grad) {
5590
- GGML_ASSERT(src0->type == tensor->type);
5591
- GGML_ASSERT(tensor->grad->type == tensor->type);
5592
- GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
5593
-
5594
- tensor_grad_view = ggml_view_4d(ctx,
5595
- tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
5596
- nb1, nb2, nb3, offset);
5597
- }
5598
 
5599
- if (src0->grad) {
5600
- src0->grad = ggml_add_or_set(ctx,
5601
- src0->grad,
5602
- ggml_acc_impl(ctx,
5603
- tensor->grad,
5604
- ggml_neg(ctx, tensor_grad_view),
5605
- nb1, nb2, nb3, offset, false),
5606
- zero_table, acc_table);
5607
- }
5608
 
5609
- if (src1->grad) {
5610
- src1->grad =
5611
- ggml_add_or_set(ctx,
5612
- src1->grad,
5613
- ggml_reshape(ctx,
5614
- ggml_cont(ctx, tensor_grad_view),
5615
- src1->grad),
5616
- zero_table, acc_table);
5617
- }
5618
- } break;
5619
- case GGML_OP_CPY:
5620
- {
5621
- // necessary for llama
5622
- // cpy overwrites value of src1 by src0 and returns view(src1)
5623
- // the overwriting is mathematically equivalent to:
5624
- // tensor = src0 * 1 + src1 * 0
5625
- if (src0->grad) {
5626
- // dsrc0 = dtensor * 1
5627
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5628
- }
5629
- if (src1->grad) {
5630
- // dsrc1 = dtensor * 0 -> noop
5631
- }
5632
- } break;
5633
- case GGML_OP_CONT:
5634
- {
5635
- // same as cpy
5636
- if (src0->grad) {
5637
- GGML_ASSERT(ggml_is_contiguous(src0->grad));
5638
- GGML_ASSERT(ggml_is_contiguous(tensor->grad));
5639
- src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
5640
- }
5641
- } break;
5642
- case GGML_OP_RESHAPE:
5643
- {
5644
- // necessary for llama
5645
- if (src0->grad) {
5646
- src0->grad =
5647
- ggml_add_or_set(ctx, src0->grad,
5648
- ggml_reshape(ctx,
5649
- ggml_is_contiguous(tensor->grad)
5650
- ? tensor->grad
5651
- : ggml_cont(ctx, tensor->grad),
5652
- src0->grad),
5653
- zero_table, acc_table);
5654
- }
5655
- } break;
5656
- case GGML_OP_VIEW:
5657
- {
5658
- // necessary for llama
5659
- if (src0->grad) {
5660
- size_t offset;
5661
-
5662
- memcpy(&offset, tensor->op_params, sizeof(offset));
5663
-
5664
- size_t nb1 = tensor->nb[1];
5665
- size_t nb2 = tensor->nb[2];
5666
- size_t nb3 = tensor->nb[3];
5667
-
5668
- if (src0->type != src0->grad->type) {
5669
- // gradient is typically F32, but src0 could be other type
5670
- size_t ng = ggml_element_size(src0->grad);
5671
- size_t n0 = ggml_element_size(src0);
5672
- GGML_ASSERT(offset % n0 == 0);
5673
- GGML_ASSERT(nb1 % n0 == 0);
5674
- GGML_ASSERT(nb2 % n0 == 0);
5675
- GGML_ASSERT(nb3 % n0 == 0);
5676
- offset = (offset / n0) * ng;
5677
- nb1 = (nb1 / n0) * ng;
5678
- nb2 = (nb2 / n0) * ng;
5679
- nb3 = (nb3 / n0) * ng;
5680
- }
5681
-
5682
- src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
5683
- }
5684
- } break;
5685
- case GGML_OP_PERMUTE:
5686
- {
5687
- // necessary for llama
5688
- if (src0->grad) {
5689
- int32_t * axes = (int32_t *) tensor->op_params;
5690
- int axis0 = axes[0] & 0x3;
5691
- int axis1 = axes[1] & 0x3;
5692
- int axis2 = axes[2] & 0x3;
5693
- int axis3 = axes[3] & 0x3;
5694
- int axes_backward[4] = {0,0,0,0};
5695
- axes_backward[axis0] = 0;
5696
- axes_backward[axis1] = 1;
5697
- axes_backward[axis2] = 2;
5698
- axes_backward[axis3] = 3;
5699
- src0->grad =
5700
- ggml_add_or_set(ctx, src0->grad,
5701
- ggml_permute(ctx,
5702
- tensor->grad,
5703
- axes_backward[0],
5704
- axes_backward[1],
5705
- axes_backward[2],
5706
- axes_backward[3]),
5707
- zero_table, acc_table);
5708
- }
5709
- } break;
5710
- case GGML_OP_TRANSPOSE:
5711
- {
5712
- // necessary for llama
5713
- if (src0->grad) {
5714
- src0->grad =
5715
- ggml_add_or_set(ctx, src0->grad,
5716
- ggml_transpose(ctx, tensor->grad),
5717
- zero_table, acc_table);
5718
- }
5719
- } break;
5720
- case GGML_OP_GET_ROWS:
5721
- {
5722
- // necessary for llama (only for tokenizer)
5723
- if (src0->grad) {
5724
- src0->grad =
5725
- ggml_add_or_set(ctx, src0->grad,
5726
- // last ggml_get_rows_back argument src0->grad is only
5727
- // necessary to setup correct output shape
5728
- ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
5729
- zero_table, acc_table);
5730
- }
5731
- if (src1->grad) {
5732
- // noop
5733
- }
5734
- } break;
5735
- case GGML_OP_GET_ROWS_BACK:
5736
- {
5737
- GGML_ABORT("fatal error"); // TODO: not implemented
5738
  }
5739
- case GGML_OP_DIAG:
5740
- {
5741
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5742
  }
5743
- case GGML_OP_DIAG_MASK_INF:
5744
- {
5745
- // necessary for llama
5746
- if (src0->grad) {
5747
- const int n_past = ((int32_t *) tensor->op_params)[0];
5748
- src0->grad =
5749
- ggml_add_or_set(ctx, src0->grad,
5750
- /* ggml_diag_mask_inf_impl() shouldn't be here */
5751
- /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
5752
- ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
5753
- zero_table, acc_table);
5754
- }
5755
- } break;
5756
- case GGML_OP_DIAG_MASK_ZERO:
5757
- {
5758
- // necessary for llama
5759
- if (src0->grad) {
5760
- const int n_past = ((int32_t *) tensor->op_params)[0];
5761
- src0->grad =
5762
- ggml_add_or_set(ctx, src0->grad,
5763
- ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
5764
- zero_table, acc_table);
5765
- }
5766
- } break;
5767
- case GGML_OP_SOFT_MAX:
5768
- {
5769
- // necessary for llama
5770
- if (src0->grad) {
5771
- src0->grad =
5772
- ggml_add_or_set(ctx, src0->grad,
5773
- ggml_soft_max_back(ctx, tensor->grad, tensor),
5774
- zero_table, acc_table);
5775
- }
5776
- GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented");
5777
- } break;
5778
- case GGML_OP_SOFT_MAX_BACK:
5779
- {
5780
- GGML_ABORT("fatal error"); // TODO: not implemented
5781
  }
5782
- case GGML_OP_ROPE:
5783
- {
5784
- // necessary for llama
5785
- if (src0->grad) {
5786
- //const int n_past = ((int32_t *) tensor->op_params)[0];
5787
- const int n_dims = ((int32_t *) tensor->op_params)[1];
5788
- const int mode = ((int32_t *) tensor->op_params)[2];
5789
- //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5790
- const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
5791
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5792
-
5793
- memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
5794
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
5795
- memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
5796
- memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
5797
- memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
5798
- memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
5799
-
5800
- src0->grad = ggml_add_or_set(ctx,
5801
- src0->grad,
5802
- ggml_rope_back(ctx,
5803
- tensor->grad,
5804
- src1,
5805
- src2,
5806
- n_dims,
5807
- mode,
5808
- n_ctx_orig,
5809
- freq_base,
5810
- freq_scale,
5811
- ext_factor,
5812
- attn_factor,
5813
- beta_fast,
5814
- beta_slow),
5815
- zero_table, acc_table);
5816
- }
5817
- GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented");
5818
- } break;
5819
- case GGML_OP_ROPE_BACK:
5820
- {
5821
- if (src0->grad) {
5822
- //const int n_past = ((int32_t *) tensor->op_params)[0];
5823
- const int n_dims = ((int32_t *) tensor->op_params)[1];
5824
- const int mode = ((int32_t *) tensor->op_params)[2];
5825
- //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5826
- const int n_ctx_orig = ((int32_t *) tensor->op_params)[4];
5827
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5828
-
5829
- memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
5830
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
5831
- memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
5832
- memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
5833
- memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
5834
- memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
5835
-
5836
- src0->grad = ggml_add_or_set(ctx,
5837
- src0->grad,
5838
- ggml_rope_impl(ctx,
5839
- tensor->grad,
5840
- src1,
5841
- src2,
5842
- n_dims,
5843
- mode,
5844
- n_ctx_orig,
5845
- freq_base,
5846
- freq_scale,
5847
- ext_factor,
5848
- attn_factor,
5849
- beta_fast,
5850
- beta_slow,
5851
- false),
5852
- zero_table, acc_table);
5853
  }
5854
- } break;
5855
- case GGML_OP_CLAMP:
5856
- {
5857
- GGML_ABORT("fatal error"); // TODO: not implemented
5858
  }
5859
- case GGML_OP_CONV_TRANSPOSE_1D:
5860
- {
5861
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5862
  }
5863
- case GGML_OP_IM2COL:
5864
- {
5865
- if (src1->grad) {
5866
- const int32_t s0 = ggml_get_op_params_i32(tensor, 0);
5867
- const int32_t s1 = ggml_get_op_params_i32(tensor, 1);
5868
- const int32_t p0 = ggml_get_op_params_i32(tensor, 2);
5869
- const int32_t p1 = ggml_get_op_params_i32(tensor, 3);
5870
- const int32_t d0 = ggml_get_op_params_i32(tensor, 4);
5871
- const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
5872
- const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
5873
-
5874
- src1->grad = ggml_add_or_set(ctx,
5875
- src1->grad,
5876
- ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
5877
- zero_table, acc_table);
5878
- }
5879
- } break;
5880
- case GGML_OP_IM2COL_BACK:
5881
- {
5882
- GGML_ABORT("fatal error"); // TODO: not implemented
5883
  }
5884
- case GGML_OP_CONV_TRANSPOSE_2D:
5885
- {
5886
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5887
  }
5888
- case GGML_OP_POOL_1D:
5889
- {
5890
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5891
  }
5892
- case GGML_OP_POOL_2D:
5893
- {
5894
- if (src0->grad) {
5895
- const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
5896
- const int32_t k0 = ggml_get_op_params_i32(tensor, 1);
5897
- const int32_t k1 = ggml_get_op_params_i32(tensor, 2);
5898
- const int32_t s0 = ggml_get_op_params_i32(tensor, 3);
5899
- const int32_t s1 = ggml_get_op_params_i32(tensor, 4);
5900
- const int32_t p0 = ggml_get_op_params_i32(tensor, 5);
5901
- const int32_t p1 = ggml_get_op_params_i32(tensor, 6);
5902
-
5903
- src0->grad = ggml_add_or_set(ctx,
5904
- src0->grad,
5905
- ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
5906
- zero_table, acc_table);
5907
- }
5908
- } break;
5909
- case GGML_OP_POOL_2D_BACK:
5910
- {
5911
- GGML_ABORT("fatal error"); // TODO: not implemented
5912
  }
5913
- case GGML_OP_UPSCALE:
5914
- {
5915
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5916
  }
5917
- case GGML_OP_PAD:
5918
- {
5919
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5920
  }
5921
- case GGML_OP_ARANGE:
5922
- {
5923
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5924
  }
5925
- case GGML_OP_TIMESTEP_EMBEDDING:
5926
- {
5927
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5928
  }
5929
- case GGML_OP_ARGSORT:
5930
- {
5931
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5932
  }
5933
- case GGML_OP_LEAKY_RELU:
5934
- {
5935
- GGML_ABORT("fatal error"); // TODO: not implemented
 
5936
  }
5937
- case GGML_OP_FLASH_ATTN_EXT:
5938
- {
5939
- GGML_ABORT("FA backward pass not adapted after rework");
5940
- struct ggml_tensor * flash_grad = NULL;
5941
- if (src0->grad || src1->grad || tensor->src[2]->grad) {
5942
- int32_t t = ggml_get_op_params_i32(tensor, 0);
5943
- GGML_ASSERT(t == 0 || t == 1);
5944
- bool masked = t != 0;
5945
- flash_grad =
5946
- ggml_flash_attn_back(ctx,
5947
- src0,
5948
- src1,
5949
- tensor->src[2],
5950
- tensor->grad,
5951
- masked);
5952
  }
5953
 
5954
- const int64_t elem_q = ggml_nelements(src0);
5955
- const int64_t elem_k = ggml_nelements(src1);
5956
- const int64_t elem_v = ggml_nelements(src2);
5957
-
5958
- enum ggml_type result_type = flash_grad->type;
5959
- GGML_ASSERT(ggml_blck_size(result_type) == 1);
5960
- const size_t tsize = ggml_type_size(result_type);
5961
-
5962
- const size_t offs_q = 0;
5963
- const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
5964
- const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
5965
-
5966
- if (src0->grad) {
5967
- struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
5968
- struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
5969
- src0->grad = ggml_add_or_set(ctx,
5970
- src0->grad,
5971
- grad_q,
5972
- zero_table, acc_table);
5973
- }
5974
- if (src1->grad) {
5975
- struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
5976
- struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
5977
- src1->grad = ggml_add_or_set(ctx,
5978
- src1->grad,
5979
- grad_k,
5980
- zero_table, acc_table);
5981
- }
5982
- if (src2->grad) {
5983
- struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
5984
- struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
5985
- src2->grad = ggml_add_or_set(ctx,
5986
- src2->grad,
5987
- grad_v,
5988
- zero_table, acc_table);
5989
  }
5990
- } break;
5991
- case GGML_OP_FLASH_ATTN_BACK:
5992
- {
5993
- GGML_ABORT("fatal error"); // not supported
5994
  }
5995
- case GGML_OP_SSM_CONV:
5996
- case GGML_OP_SSM_SCAN:
5997
- {
5998
- GGML_ABORT("fatal error"); // TODO: not implemented
5999
  }
6000
  case GGML_OP_WIN_PART:
6001
  case GGML_OP_WIN_UNPART:
6002
- case GGML_OP_UNARY:
6003
- {
6004
- switch (ggml_get_unary_op(tensor)) {
6005
- case GGML_UNARY_OP_ABS:
6006
- {
6007
- if (src0->grad) {
6008
- src0->grad =
6009
- ggml_add_or_set(ctx,
6010
- src0->grad,
6011
- ggml_mul(ctx,
6012
- ggml_sgn(ctx, src0),
6013
- tensor->grad),
6014
- zero_table, acc_table);
6015
- }
6016
- } break;
6017
- case GGML_UNARY_OP_SGN:
6018
- {
6019
- if (src0->grad) {
6020
- // noop
6021
- }
6022
- } break;
6023
- case GGML_UNARY_OP_NEG:
6024
- {
6025
- if (src0->grad) {
6026
- src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
6027
- }
6028
- } break;
6029
- case GGML_UNARY_OP_STEP:
6030
- {
6031
- if (src0->grad) {
6032
- // noop
6033
- }
6034
- } break;
6035
- case GGML_UNARY_OP_TANH:
6036
- {
6037
- GGML_ABORT("fatal error"); // TODO: not implemented
6038
- }
6039
- case GGML_UNARY_OP_ELU:
6040
- {
6041
- GGML_ABORT("fatal error"); // TODO: not implemented
6042
- }
6043
- case GGML_UNARY_OP_RELU:
6044
- {
6045
- if (src0->grad) {
6046
- src0->grad = ggml_add_or_set(ctx,
6047
- src0->grad,
6048
- ggml_mul(ctx,
6049
- ggml_step(ctx, src0),
6050
- tensor->grad),
6051
- zero_table, acc_table);
6052
- }
6053
- } break;
6054
- case GGML_UNARY_OP_SIGMOID:
6055
- {
6056
- GGML_ABORT("fatal error"); // TODO: not implemented
6057
- }
6058
- case GGML_UNARY_OP_GELU:
6059
- {
6060
- GGML_ABORT("fatal error"); // TODO: not implemented
6061
- }
6062
- case GGML_UNARY_OP_GELU_QUICK:
6063
- {
6064
- GGML_ABORT("fatal error"); // TODO: not implemented
6065
- }
6066
- case GGML_UNARY_OP_SILU:
6067
- {
6068
- // necessary for llama
6069
- if (src0->grad) {
6070
- src0->grad = ggml_add_or_set(ctx,
6071
- src0->grad,
6072
- ggml_silu_back(ctx, src0, tensor->grad),
6073
- zero_table, acc_table);
6074
- }
6075
- } break;
6076
- case GGML_UNARY_OP_EXP:
6077
- {
6078
- if (src0->grad) {
6079
- src0->grad = ggml_add_or_set(ctx,
6080
- src0->grad,
6081
- ggml_mul(ctx, tensor, tensor->grad),
6082
- zero_table, acc_table);
6083
- }
6084
- } break;
6085
- default:
6086
- GGML_ABORT("fatal error");
6087
- }
6088
- } break;
6089
- case GGML_OP_GET_REL_POS:
6090
- case GGML_OP_ADD_REL_POS:
6091
- case GGML_OP_RWKV_WKV6:
6092
- case GGML_OP_MAP_UNARY:
6093
- case GGML_OP_MAP_BINARY:
6094
- case GGML_OP_MAP_CUSTOM1_F32:
6095
- case GGML_OP_MAP_CUSTOM2_F32:
6096
- case GGML_OP_MAP_CUSTOM3_F32:
6097
- case GGML_OP_MAP_CUSTOM1:
6098
- case GGML_OP_MAP_CUSTOM2:
6099
- case GGML_OP_MAP_CUSTOM3:
6100
- {
6101
- GGML_ABORT("fatal error"); // not supported
6102
- }
6103
- case GGML_OP_CROSS_ENTROPY_LOSS:
6104
- {
6105
- if (src0->grad) {
6106
- src0->grad = ggml_add_or_set(ctx,
6107
- src0->grad,
6108
- ggml_cross_entropy_loss_back(ctx,
6109
- src0,
6110
- src1,
6111
- tensor->grad),
6112
- zero_table, acc_table);
6113
- }
6114
- GGML_ASSERT(!src1->grad && "backward pass for labels not implemented");
6115
- } break;
6116
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
6117
- {
6118
- GGML_ABORT("fatal error"); // not supported
6119
  }
6120
- case GGML_OP_OPT_STEP_ADAMW:
6121
- {
6122
- GGML_ABORT("fatal error"); // not supported
 
6123
  }
6124
- case GGML_OP_NONE:
6125
- {
6126
- // nop
6127
- } break;
 
6128
  case GGML_OP_COUNT:
6129
- {
6130
- GGML_ABORT("fatal error");
6131
- }
 
6132
  }
6133
 
6134
- for (int i = 0; i < GGML_MAX_SRC; ++i) {
6135
- if (tensor->src[i] && tensor->src[i]->grad) {
6136
- GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad));
6137
- }
6138
- }
6139
  }
6140
 
6141
  static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
6142
- if (node->grad == NULL) {
6143
- // this usually happens when we generate intermediate nodes from constants in the backward pass
6144
- // it can also happen during forward pass, if the user performs computations with constants
6145
- if (node->op != GGML_OP_NONE) {
6146
- //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op);
6147
- }
6148
- }
6149
-
6150
  // check if already visited
6151
  if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
6152
  return;
@@ -6207,18 +5586,42 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
6207
  ggml_build_forward_impl(cgraph, tensor, true);
6208
  }
6209
 
6210
- void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) {
6211
- GGML_ASSERT(gf->n_nodes > 0);
6212
- GGML_ASSERT(gf->grads);
6213
 
6214
- for (int i = 0; i < gf->n_nodes; ++i) {
6215
- struct ggml_tensor * node = gf->nodes[i];
6216
 
6217
  if (node->type == GGML_TYPE_I32) {
6218
  continue;
6219
  }
6220
 
6221
- bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
6222
  bool ignore_src[GGML_MAX_SRC] = {false};
6223
  switch (node->op) {
6224
  // gradients in node->src[0] for one reason or another have no effect on output gradients
@@ -6246,14 +5649,14 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
6246
  break;
6247
  }
6248
  for (int j = 0; j < GGML_MAX_SRC; ++j) {
6249
- if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) {
6250
  continue;
6251
  }
6252
  GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
6253
- needs_grad = true;
6254
  break;
6255
  }
6256
- if (!needs_grad) {
6257
  continue;
6258
  }
6259
 
@@ -6261,73 +5664,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
6261
  GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
6262
  node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
6263
 
6264
- // create a new tensor with the same type and shape as the node and set it as grad
6265
- node->grad = ggml_dup_tensor(ctx, node);
6266
- }
6267
-
6268
- // keep tables of original gradients for replacement/accumulation logic
6269
- struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
6270
- struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size);
6271
- for (int i = 0; i < gf->n_nodes; i++) {
6272
- struct ggml_tensor * node = gf->nodes[i];
6273
-
6274
- if (node->grad) {
6275
- {
6276
- const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
6277
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
6278
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
6279
- }
6280
-
6281
- // only gradients of trainable parameters should be accumulated
6282
- if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
6283
- const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
6284
- GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
6285
- GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
6286
- }
6287
  }
 
6288
  }
6289
 
6290
- for (int i = gf->n_nodes - 1; i >= 0; i--) {
6291
- struct ggml_tensor * node = gf->nodes[i];
6292
-
6293
  // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
6294
  // use allocator to automatically make inplace operations
6295
- if (node->grad) {
6296
- ggml_compute_backward(ctx, node, &zero_table, &acc_table);
6297
- }
6298
  }
6299
 
6300
- for (int i = 0; i < gf->n_nodes; i++) {
6301
- struct ggml_tensor * node = gf->nodes[i];
6302
-
6303
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
6304
- GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
6305
- ggml_build_forward_expand(gb, node->grad);
6306
- }
6307
- }
6308
-
6309
- ggml_hash_set_free(&zero_table);
6310
- ggml_hash_set_free(&acc_table);
6311
- }
6312
-
6313
- void ggml_build_opt_adamw(
6314
- struct ggml_context * ctx,
6315
- struct ggml_cgraph * gf,
6316
- struct ggml_cgraph * gb,
6317
- float alpha,
6318
- float beta1,
6319
- float beta2,
6320
- float eps,
6321
- float wd) {
6322
- for (int i = 0; i < gf->n_nodes; i++) {
6323
- struct ggml_tensor * node = gf->nodes[i];
6324
-
6325
- if (node->flags & GGML_TENSOR_FLAG_PARAM) {
6326
- GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
6327
- struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
6328
- ggml_build_forward_expand(gb, opt_step);
6329
- }
6330
- }
6331
  }
6332
 
6333
  static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
@@ -6345,7 +5696,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
6345
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6346
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
6347
  if (grads) {
6348
- incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
 
6349
  }
6350
  incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
6351
 
@@ -6371,10 +5723,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
6371
 
6372
  void * p = cgraph + 1;
6373
 
6374
- struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6375
- struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6376
- struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6377
- struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
 
 
6378
  ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
6379
 
6380
  // check that we allocated the correct amount of memory
@@ -6386,12 +5740,17 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
6386
  /*.n_leafs =*/ 0,
6387
  /*.nodes =*/ nodes_ptr,
6388
  /*.grads =*/ grads_ptr,
 
6389
  /*.leafs =*/ leafs_ptr,
6390
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
6391
  /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
6392
  };
6393
 
6394
  ggml_hash_set_reset(&cgraph->visited_hash_set);
 
 
 
 
6395
 
6396
  return cgraph;
6397
  }
@@ -6407,6 +5766,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
6407
  /*.n_leafs =*/ 0,
6408
  /*.nodes =*/ cgraph0->nodes + i0,
6409
  /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
 
6410
  /*.leafs =*/ NULL,
6411
  /*.hash_table =*/ { 0, NULL, NULL },
6412
  /*.order =*/ cgraph0->order,
@@ -6432,19 +5792,23 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
6432
  dst->nodes[i] = src->nodes[i];
6433
  }
6434
 
6435
- if (src->grads) {
6436
- GGML_ASSERT(dst->grads != NULL);
6437
- for (int i = 0; i < src->n_nodes; ++i) {
6438
- dst->grads[i] = src->grads[i];
6439
- }
6440
- }
6441
-
6442
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
6443
  // copy all hashset keys (tensors) that are in use
6444
  if (ggml_bitset_get(src->visited_hash_set.used, i)) {
6445
  ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6446
  }
6447
  }
6448
  }
6449
 
6450
  struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
@@ -6470,29 +5834,36 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
6470
  GGML_ASSERT(cgraph->grads != NULL);
6471
 
6472
  for (int i = 0; i < cgraph->n_nodes; i++) {
6473
- struct ggml_tensor * node = cgraph->nodes[i];
6474
 
6475
  // initial gradients of loss should be 1, 0 otherwise
6476
- if (node->grad) {
6477
  if (node->flags & GGML_TENSOR_FLAG_LOSS) {
6478
- GGML_ASSERT(node->grad->buffer);
6479
- GGML_ASSERT(node->type == GGML_TYPE_F32);
6480
- GGML_ASSERT(ggml_is_scalar(node));
6481
 
6482
  const float onef = 1.0f;
6483
- ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
6484
  } else {
6485
- ggml_set_zero(node->grad);
6486
  }
6487
  }
6488
-
6489
- GGML_ASSERT(node);
6490
- if (node->op == GGML_OP_OPT_STEP_ADAMW) {
6491
- // set iteration to 1 and clear momenta
6492
- ggml_set_op_params_i32(node, 0, 1);
6493
- ggml_set_zero(node->src[2]);
6494
- ggml_set_zero(node->src[3]);
6495
- }
6496
  }
6497
  }
6498
 
@@ -6530,7 +5901,7 @@ void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tenso
6530
  cgraph->n_nodes++;
6531
  }
6532
 
6533
- struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
6534
  for (int i = 0; i < cgraph->n_leafs; i++) {
6535
  struct ggml_tensor * leaf = cgraph->leafs[i];
6536
 
@@ -6550,6 +5921,16 @@ struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const ch
6550
  return NULL;
6551
  }
6552
 
6553
  void ggml_graph_print(const struct ggml_cgraph * cgraph) {
6554
  GGML_LOG_INFO("=== GRAPH ===\n");
6555
 
@@ -6560,7 +5941,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
6560
  GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
6561
  i,
6562
  node->ne[0], node->ne[1], node->ne[2],
6563
- ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
 
6564
  }
6565
 
6566
  GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
@@ -6595,8 +5977,9 @@ static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml
6595
  static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
6596
  for (int i = 0; i < cgraph->n_nodes; i++) {
6597
  struct ggml_tensor * parent = cgraph->nodes[i];
 
6598
 
6599
- if (parent->grad == node) {
6600
  return parent;
6601
  }
6602
  }
@@ -6636,6 +6019,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
6636
 
6637
  for (int i = 0; i < gb->n_nodes; i++) {
6638
  struct ggml_tensor * node = gb->nodes[i];
 
6639
 
6640
  if (ggml_graph_get_parent(gb, node) != NULL) {
6641
  continue;
@@ -6643,7 +6027,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
6643
 
6644
  if (node->flags & GGML_TENSOR_FLAG_PARAM) {
6645
  snprintf(color, sizeof(color), "yellow");
6646
- } else if (node->grad) {
6647
  if (ggml_graph_find(gf, node)) {
6648
  snprintf(color, sizeof(color), "green");
6649
  } else {
@@ -6670,8 +6054,8 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
6670
  fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
6671
  }
6672
 
6673
- if (node->grad) {
6674
- fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(node->grad->op));
6675
  } else {
6676
  fprintf(fp, "\"; ]\n");
6677
  }
 
1592
  /*.op =*/ GGML_OP_NONE,
1593
  /*.op_params =*/ { 0 },
1594
  /*.flags =*/ 0,
 
1595
  /*.src =*/ { NULL },
1596
  /*.view_src =*/ view_src,
1597
  /*.view_offs =*/ view_offs,
1598
  /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
1599
  /*.name =*/ { 0 },
1600
  /*.extra =*/ NULL,
1601
+ /*.padding =*/ { 0 },
1602
  };
1603
 
1604
  #ifdef __clang__
 
4193
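Note on the tensor initializer above: it no longer sets a grad member. Gradient storage now lives in the cgraph (grads/grad_accs arrays further down in this diff), so the old tensor->grad access pattern becomes a lookup on the graph that was differentiated. A minimal sketch, assuming graph is such a cgraph and weight is one of its tensors (both names are illustrative, not from this commit):

    // gradients are no longer reachable as tensor->grad; they are owned by the
    // cgraph and retrieved through the accessor added later in this diff:
    struct ggml_tensor * grad = ggml_graph_get_grad(graph, weight); // NULL if no gradient was built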
  GGML_ASSERT(mask);
4194
  }
4195
 
 
 
4196
  // permute(0, 2, 1, 3)
4197
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
4198
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
4200
  float params[] = { scale, max_bias, logit_softcap };
4201
  ggml_set_op_params(result, params, sizeof(params));
4202
 
4203
+ result->op = GGML_OP_FLASH_ATTN_EXT;
 
4204
  result->src[0] = q;
4205
  result->src[1] = k;
4206
  result->src[2] = v;
 
4268
 
4269
  GGML_ASSERT(ne2 % kvne2 == 0);
4270
 
 
4271
  // store gradients of q, k and v as continuous tensors concatenated in result.
4272
  // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
4273
  const int64_t elem_q = ggml_nelements(q);
 
4290
  int32_t masked_i = masked ? 1 : 0;
4291
  ggml_set_op_params(result, &masked_i, sizeof(masked_i));
4292
 
4293
+ result->op = GGML_OP_FLASH_ATTN_BACK;
 
4294
  result->src[0] = q;
4295
  result->src[1] = k;
4296
  result->src[2] = v;
 
4932
  struct ggml_context * ctx,
4933
  struct ggml_tensor * a,
4934
  struct ggml_tensor * grad,
4935
+ struct ggml_tensor * m,
4936
+ struct ggml_tensor * v,
4937
+ struct ggml_tensor * adamw_params) {
 
 
4938
  GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM);
4939
  GGML_ASSERT(ggml_are_same_shape(a, grad));
4940
+ GGML_ASSERT(ggml_are_same_shape(a, m));
4941
+ GGML_ASSERT(ggml_are_same_shape(a, v));
4942
+ GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
4943
+ GGML_ASSERT(ggml_nelements(adamw_params) == 7);
 
4944
 
4945
  struct ggml_tensor * result = ggml_view_tensor(ctx, a);
4946
 
 
4947
  result->op = GGML_OP_OPT_STEP_ADAMW;
4948
  result->src[0] = a;
4949
  result->src[1] = grad;
4950
+ result->src[2] = m;
4951
+ result->src[3] = v;
4952
+ result->src[4] = adamw_params;
4953
 
4954
  return result;
4955
  }
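For orientation, a rough sketch of how a caller might wire up the reworked optimizer step: the scalar hyperparameters are gone from the signature, and the first/second moments plus a 7-element F32 parameter tensor are passed in as regular tensors. The names, the context ctx and the target graph gb are illustrative assumptions, and the exact meaning of the seven adamw_params values is defined by the accompanying ggml-opt code rather than by this hunk:

    // "weight" is assumed to be a tensor previously marked with ggml_set_param
    struct ggml_tensor * m            = ggml_dup_tensor(ctx, weight);  // first moment, same shape as weight
    struct ggml_tensor * v            = ggml_dup_tensor(ctx, weight);  // second moment, same shape as weight
    struct ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7); // filled in at runtime

    // the gradient now comes from the graph (accessor introduced later in this diff)
    struct ggml_tensor * grad = ggml_graph_get_grad(gb, weight);
    struct ggml_tensor * step = ggml_opt_step_adamw(ctx, weight, grad, m, v, adamw_params);
    ggml_build_forward_expand(gb, step); // append the parameter update to the graph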
 
5018
  GGML_FREE(map);
5019
  }
5020
 
 
5021
  // utility functions to change gradients
5022
  // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
5023
  // else if a is in zero_table, replace a
5024
  // else, just add/subtract/etc. the gradients
5025
 
5026
+ static void ggml_add_or_set(
5027
+ struct ggml_context * ctx,
5028
+ struct ggml_cgraph * cgraph,
5029
+ size_t isrc,
5030
+ struct ggml_tensor * tensor) {
5031
+ if (cgraph->grads[isrc]) {
5032
+ cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
5033
+ } else {
5034
+ cgraph->grads[isrc] = tensor;
5035
  }
5036
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5037
  }
5038
 
5039
+ static void ggml_acc_or_set(
5040
+ struct ggml_context * ctx,
5041
+ struct ggml_cgraph * cgraph,
5042
+ size_t isrc,
5043
+ struct ggml_tensor * src,
5044
+ struct ggml_tensor * tensor,
5045
+ const size_t nb1,
5046
+ const size_t nb2,
5047
+ const size_t nb3,
5048
+ const size_t offset) {
5049
+ if (cgraph->grads[isrc]) {
5050
+ cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
5051
+ } else {
5052
+ struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
5053
+ cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
5054
  }
5055
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5056
  }
5057
 
5058
+ static void ggml_add1_or_set(
5059
+ struct ggml_context * ctx,
5060
+ struct ggml_cgraph * cgraph,
5061
+ size_t isrc,
5062
+ struct ggml_tensor * src,
5063
+ struct ggml_tensor * tensor) {
5064
+ if (cgraph->grads[isrc]) {
5065
+ cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
5066
+ } else {
5067
+ cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
5068
  }
5069
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5070
  }
5071
 
5072
+ static void ggml_sub_or_set(
5073
+ struct ggml_context * ctx,
5074
+ struct ggml_cgraph * cgraph,
5075
+ size_t isrc,
5076
+ struct ggml_tensor * tensor) {
5077
+ if (cgraph->grads[isrc]) {
5078
+ cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
5079
+ } else {
5080
+ cgraph->grads[isrc] = ggml_neg(ctx, tensor);
5081
  }
5082
+ ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
5083
  }
5084
 
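All four helpers above share one shape, shown here for the add case as a compressed sketch (not code from the commit). The subtlety worth calling out is that the grad_accs entry doubles as the inplace flag, so a contribution to a persistent gradient accumulator is written into the accumulator's storage rather than into a freshly allocated tensor; acc/add1/sub differ only in the combining op and in how a missing gradient is initialized.

    static void add_grad_contribution(struct ggml_context * ctx, struct ggml_cgraph * cg,
                                      size_t isrc, struct ggml_tensor * contrib) {
        if (cg->grads[isrc]) {
            // in place iff this slot is a gradient accumulator
            cg->grads[isrc] = ggml_add_impl(ctx, cg->grads[isrc], contrib, /*inplace =*/ cg->grad_accs[isrc] != NULL);
        } else {
            cg->grads[isrc] = contrib; // the first contribution simply becomes the gradient
        }
        ggml_build_forward_expand(cg, cg->grads[isrc]);
    }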
5085
+ static void ggml_compute_backward(
5086
+ struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
5087
+ struct ggml_tensor * tensor = cgraph->nodes[i];
5088
+ struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor);
5089
+
5090
+ if (!grad) {
5091
+ return;
5092
+ }
5093
+
5094
  struct ggml_tensor * src0 = tensor->src[0];
5095
  struct ggml_tensor * src1 = tensor->src[1];
5096
  struct ggml_tensor * src2 = tensor->src[2];
5097
+ struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
5098
+ const size_t isrc0 = ggml_hash_find(hash_set, src0);
5099
+ const size_t isrc1 = ggml_hash_find(hash_set, src1);
5100
+ const size_t isrc2 = ggml_hash_find(hash_set, src2);
5101
+ const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
5102
+ const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
5103
+ const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
5104
 
5105
  switch (tensor->op) {
5106
+ case GGML_OP_DUP: {
5107
+ if (src0_needs_grads) {
5108
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5109
  }
5110
+ } break;
5111
+ case GGML_OP_ADD: {
5112
+ if (src0_needs_grads) {
5113
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5114
  }
5115
+ if (src1_needs_grads) {
5116
+ struct ggml_tensor * tmp = grad;
5117
+ if (!ggml_are_same_shape(src0, src1)) {
5118
+ tmp = ggml_repeat_back(ctx, tmp, src1);
5119
  }
5120
+ ggml_add_or_set(ctx, cgraph, isrc1, tmp);
 
 
 
5121
  }
5122
+ } break;
5123
+ case GGML_OP_ADD1: {
5124
+ if (src0_needs_grads) {
5125
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5126
  }
5127
+ if (src1_needs_grads) {
5128
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean
5129
  }
5130
+ } break;
5131
+ case GGML_OP_ACC: {
5132
+ if (src0_needs_grads) {
5133
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5134
  }
5135
+ if (src1_needs_grads) {
5136
+ const size_t nb1 = ((int32_t *) tensor->op_params)[0];
5137
+ const size_t nb2 = ((int32_t *) tensor->op_params)[1];
5138
+ const size_t nb3 = ((int32_t *) tensor->op_params)[2];
5139
+ const size_t offset = ((int32_t *) tensor->op_params)[3];
5140
 
5141
+ struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
5142
+ grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
5143
+ nb1, nb2, nb3, offset);
5144
 
5145
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
5146
  }
5147
+ } break;
5148
+ case GGML_OP_SUB: {
5149
+ if (src0_needs_grads) {
5150
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5151
  }
5152
+ if (src1_needs_grads) {
5153
+ ggml_sub_or_set(ctx, cgraph, isrc1, grad);
5154
  }
5155
+ } break;
5156
+ case GGML_OP_MUL: {
5157
+ if (src0_needs_grads) {
5158
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
5159
+ }
5160
+ if (src1_needs_grads) {
5161
+ struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
5162
+ if (!ggml_are_same_shape(src0, src1)) {
5163
+ tmp = ggml_repeat_back(ctx, tmp, src1);
5164
  }
5165
+ ggml_add_or_set(ctx, cgraph, isrc1, tmp);
 
 
 
5166
  }
5167
+ } break;
5168
+ case GGML_OP_DIV: {
5169
+ if (src0_needs_grads) {
5170
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1));
5171
  }
5172
+ if (src1_needs_grads) {
5173
+ ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1)));
5174
  }
5175
+ } break;
5176
+ case GGML_OP_SQR: {
5177
+ if (src0_needs_grads) {
5178
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f));
5179
  }
5180
+ } break;
5181
+ case GGML_OP_SQRT: {
5182
+ if (src0_needs_grads) {
5183
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f));
5184
  }
5185
+ } break;
5186
+ case GGML_OP_LOG: {
5187
+ if (src0_needs_grads) {
5188
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0));
5189
  }
5190
+ } break;
5191
+ case GGML_OP_SIN: {
5192
+ if (src0_needs_grads) {
5193
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0)));
5194
  }
5195
+ } break;
5196
+ case GGML_OP_COS: {
5197
+ if (src0_needs_grads) {
5198
+ ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0)));
5199
  }
5200
+ } break;
5201
+ case GGML_OP_SUM: {
5202
+ if (src0_needs_grads) {
5203
+ ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
5204
  }
5205
+ } break;
5206
+ case GGML_OP_SUM_ROWS: {
5207
+ if (src0_needs_grads) {
5208
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
5209
  }
5210
+ } break;
5211
+ case GGML_OP_MEAN: {
5212
+ if (src0_needs_grads) {
5213
+ ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
5214
  }
5215
+ } break;
5216
+ case GGML_OP_REPEAT: {
5217
+ if (src0_needs_grads) {
5218
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0));
5219
  }
5220
+ } break;
5221
+ case GGML_OP_REPEAT_BACK: {
5222
+ if (src0_needs_grads) {
5223
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0));
5224
+ }
5225
+ } break;
5226
+ case GGML_OP_RMS_NORM: {
5227
+ if (src0_needs_grads) {
5228
+ float eps;
5229
+ memcpy(&eps, tensor->op_params, sizeof(float));
5230
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
5231
+ }
5232
+ } break;
5233
+ case GGML_OP_MUL_MAT: {
5234
+ // https://cs231n.github.io/optimization-2/#staged
5235
+ // # forward pass
5236
+ // s0 = np.random.randn(5, 10)
5237
+ // s1 = np.random.randn(10, 3)
5238
+ // t = s0.dot(s1)
5239
+
5240
+ // # now suppose we had the gradient on t from above in the circuit
5241
+ // dt = np.random.randn(*t.shape) # same shape as t
5242
+ // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix
5243
+ // ds1 = t.T.dot(dt)
5244
+
5245
+ // tensor.shape [m,p,qq,rr]
5246
+ // src0.shape [n,m,q1,r1]
5247
+ // src1.shape [n,p,qq,rr]
5248
+
5249
+ if (src0_needs_grads) {
5250
+ struct ggml_tensor * s1_tg =
5251
+ ggml_out_prod(ctx, // [n,m,qq,rr]
5252
+ src1, // [n,p,qq,rr]
5253
+ grad); // [m,p,qq,rr]
5254
+ const int64_t qq = s1_tg->ne[2];
5255
+ const int64_t rr = s1_tg->ne[3];
5256
+ const int64_t q1 = src0->ne[2];
5257
+ const int64_t r1 = src0->ne[3];
5258
+ const bool ne2_broadcasted = qq > q1;
5259
+ const bool ne3_broadcasted = rr > r1;
5260
+ if (ne2_broadcasted || ne3_broadcasted) {
5261
+ // sum broadcast repetitions of s1_tg into shape of src0
5262
+ s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
5263
  }
5264
+ ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
5265
+ }
5266
+ if (src1_needs_grads) {
5267
+ ggml_add_or_set(ctx, cgraph, isrc1,
5268
+ // ggml_mul_mat(ctx, // [n,p,qq,rr]
5269
+ // ggml_cont(ctx, // [m,n,q1,r1]
5270
+ // ggml_transpose(ctx, src0)), // [m,n,q1,r1]
5271
+ // grad), // [m,p,qq,rr]
5272
+
5273
+ // when src0 is bigger than tensor->grad (this is mostly the case in llama),
5274
+ // avoid transpose of src0, rather transpose smaller tensor->grad
5275
+ // and then use ggml_out_prod
5276
+ ggml_out_prod(ctx, // [n,p,qq,rr]
5277
+ src0, // [n,m,q1,r1]
5278
+ ggml_transpose(ctx, // [p,m,qq,rr]
5279
+ grad))); // [m,p,qq,rr]
5280
+ }
5281
+ } break;
5282
+ case GGML_OP_SCALE: {
5283
+ if (src0_needs_grads) {
5284
+ float s;
5285
+ memcpy(&s, tensor->op_params, sizeof(float));
5286
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
5287
+ }
5288
+ } break;
5289
+ case GGML_OP_SET: {
5290
+ const size_t nb1 = ((const int32_t *) tensor->op_params)[0];
5291
+ const size_t nb2 = ((const int32_t *) tensor->op_params)[1];
5292
+ const size_t nb3 = ((const int32_t *) tensor->op_params)[2];
5293
+ const size_t offset = ((const int32_t *) tensor->op_params)[3];
5294
+
5295
+ struct ggml_tensor * tensor_grad_view = NULL;
5296
+
5297
+ if (src0_needs_grads || src1_needs_grads) {
5298
+ GGML_ASSERT(src0->type == tensor->type);
5299
+ GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type);
5300
+ GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type);
5301
+
5302
+ tensor_grad_view = ggml_view_4d(ctx,
5303
+ grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
5304
+ nb1, nb2, nb3, offset);
5305
+ }
5306
 
5307
+ if (src0_needs_grads) {
5308
+ struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view);
5309
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false));
5310
+ }
5311
+
5312
+ if (src1_needs_grads) {
5313
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1));
5314
+ }
5315
+ } break;
5316
+ case GGML_OP_CPY: {
5317
+ // cpy overwrites value of src1 by src0 and returns view(src1)
5318
+ // the overwriting is mathematically equivalent to:
5319
+ // tensor = src0 * 1 + src1 * 0
5320
+ if (src0_needs_grads) {
5321
+ // dsrc0 = dtensor * 1
5322
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5323
+ }
5324
+ if (src1_needs_grads) {
5325
+ // dsrc1 = dtensor * 0 -> noop
5326
+ }
5327
+ } break;
5328
+ case GGML_OP_CONT: {
5329
+ // same as cpy
5330
+ if (src0_needs_grads) {
5331
+ GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
5332
+ GGML_ASSERT(ggml_is_contiguous(grad));
5333
+ ggml_add_or_set(ctx, cgraph, isrc0, grad);
5334
+ }
5335
+ } break;
5336
+ case GGML_OP_RESHAPE: {
5337
+ if (src0_needs_grads) {
5338
+ struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad);
5339
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0));
5340
+ }
5341
+ } break;
5342
+ case GGML_OP_VIEW: {
5343
+ if (src0_needs_grads) {
5344
+ size_t offset;
5345
+
5346
+ memcpy(&offset, tensor->op_params, sizeof(offset));
5347
+
5348
+ size_t nb1 = tensor->nb[1];
5349
+ size_t nb2 = tensor->nb[2];
5350
+ size_t nb3 = tensor->nb[3];
5351
+
5352
+ if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) {
5353
+ // gradient is typically F32, but src0 could be other type
5354
+ size_t ng = ggml_element_size(cgraph->grads[isrc0]);
5355
+ size_t n0 = ggml_element_size(src0);
5356
+ GGML_ASSERT(offset % n0 == 0);
5357
+ GGML_ASSERT(nb1 % n0 == 0);
5358
+ GGML_ASSERT(nb2 % n0 == 0);
5359
+ GGML_ASSERT(nb3 % n0 == 0);
5360
+ offset = (offset / n0) * ng;
5361
+ nb1 = (nb1 / n0) * ng;
5362
+ nb2 = (nb2 / n0) * ng;
5363
+ nb3 = (nb3 / n0) * ng;
5364
  }
5365
+
5366
+ ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
 
 
5367
  }
5368
+ } break;
5369
+ case GGML_OP_PERMUTE: {
5370
+ if (src0_needs_grads) {
5371
+ const int32_t * axes = (const int32_t *) tensor->op_params;
5372
+ const int axis0 = axes[0] & 0x3;
5373
+ const int axis1 = axes[1] & 0x3;
5374
+ const int axis2 = axes[2] & 0x3;
5375
+ const int axis3 = axes[3] & 0x3;
5376
+ int axb[4] = {0,0,0,0}; // axes backward
5377
+ axb[axis0] = 0;
5378
+ axb[axis1] = 1;
5379
+ axb[axis2] = 2;
5380
+ axb[axis3] = 3;
5381
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3]));
5382
  }
5383
+ } break;
5384
+ case GGML_OP_TRANSPOSE: {
5385
+ if (src0_needs_grads) {
5386
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad));
5387
+ }
5388
+ } break;
5389
+ case GGML_OP_GET_ROWS: {
5390
+ if (src0_needs_grads) {
5391
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0));
5392
+ }
5393
+ if (src1_needs_grads) {
5394
+ // noop
5395
+ }
5396
+ } break;
5397
+ case GGML_OP_DIAG_MASK_INF: {
5398
+ if (src0_needs_grads) {
5399
+ /* ggml_diag_mask_inf_impl() shouldn't be here */
5400
+ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
5401
+ const int n_past = ((const int32_t *) tensor->op_params)[0];
5402
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
5403
+ }
5404
+ } break;
5405
+ case GGML_OP_DIAG_MASK_ZERO: {
5406
+ if (src0_needs_grads) {
5407
+ const int n_past = ((const int32_t *) tensor->op_params)[0];
5408
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false));
5409
+ }
5410
+ } break;
5411
+ case GGML_OP_SOFT_MAX: {
5412
+ if (src0_needs_grads) {
5413
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
5414
+ }
5415
+ GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
5416
+ } break;
5417
+ case GGML_OP_ROPE: {
5418
+ if (src0_needs_grads) {
5419
+ //const int n_past = ((int32_t *) tensor->op_params)[0];
5420
+ const int n_dims = ((const int32_t *) tensor->op_params)[1];
5421
+ const int mode = ((const int32_t *) tensor->op_params)[2];
5422
+ //const int n_ctx = ((int32_t *) tensor->op_params)[3];
5423
+ const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
5424
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5425
+
5426
+ memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float));
5427
+ memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float));
5428
+ memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float));
5429
+ memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float));
5430
+ memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float));
5431
+ memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float));
5432
+
5433
+ ggml_add_or_set(ctx, cgraph, isrc0,
5434
+ ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
5435
+ freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
5436
+ }
5437
+ GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
5438
+ } break;
5439
+ case GGML_OP_IM2COL: {
5440
+ if (src1_needs_grads) {
5441
+ const int32_t s0 = ggml_get_op_params_i32(tensor, 0);
5442
+ const int32_t s1 = ggml_get_op_params_i32(tensor, 1);
5443
+ const int32_t p0 = ggml_get_op_params_i32(tensor, 2);
5444
+ const int32_t p1 = ggml_get_op_params_i32(tensor, 3);
5445
+ const int32_t d0 = ggml_get_op_params_i32(tensor, 4);
5446
+ const int32_t d1 = ggml_get_op_params_i32(tensor, 5);
5447
+ const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
5448
+
5449
+ ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
5450
+ }
5451
+ } break;
5452
+ case GGML_OP_POOL_2D: {
5453
+ if (src0_needs_grads) {
5454
+ const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0);
5455
+ const int32_t k0 = ggml_get_op_params_i32(tensor, 1);
5456
+ const int32_t k1 = ggml_get_op_params_i32(tensor, 2);
5457
+ const int32_t s0 = ggml_get_op_params_i32(tensor, 3);
5458
+ const int32_t s1 = ggml_get_op_params_i32(tensor, 4);
5459
+ const int32_t p0 = ggml_get_op_params_i32(tensor, 5);
5460
+ const int32_t p1 = ggml_get_op_params_i32(tensor, 6);
5461
+
5462
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1));
5463
+ }
5464
+ } break;
5465
  case GGML_OP_WIN_PART:
5466
  case GGML_OP_WIN_UNPART:
5467
+ case GGML_OP_UNARY: {
5468
+ switch (ggml_get_unary_op(tensor)) {
5469
+ case GGML_UNARY_OP_ABS: {
5470
+ if (src0_needs_grads) {
5471
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad));
5472
+ }
5473
+ } break;
5474
+ case GGML_UNARY_OP_SGN: {
5475
+ // noop
5476
+ } break;
5477
+ case GGML_UNARY_OP_NEG: {
5478
+ if (src0_needs_grads) {
5479
+ ggml_sub_or_set(ctx, cgraph, isrc0, grad);
5480
+ }
5481
+ } break;
5482
+ case GGML_UNARY_OP_STEP: {
5483
+ // noop
5484
+ } break;
5485
+ case GGML_UNARY_OP_RELU: {
5486
+ if (src0_needs_grads) {
5487
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad));
5488
+ }
5489
+ } break;
5490
+ case GGML_UNARY_OP_SILU: {
5491
+ if (src0_needs_grads) {
5492
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
5493
+ }
5494
+ } break;
5495
+ case GGML_UNARY_OP_EXP: {
5496
+ if (src0_needs_grads) {
5497
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad));
5498
+ }
5499
+ } break;
5500
+ default: {
5501
+ fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n",
5502
+ __func__, ggml_unary_op_name(ggml_get_unary_op(tensor)));
5503
+ GGML_ABORT("fatal error");
5504
+ } break;
5505
  }
5506
+ } break;
5507
+ case GGML_OP_CROSS_ENTROPY_LOSS: {
5508
+ if (src0_needs_grads) {
5509
+ ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
5510
  }
5511
+ GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
5512
+ } break;
5513
+ case GGML_OP_NONE: {
5514
+ // noop
5515
+ } break;
5516
  case GGML_OP_COUNT:
5517
+ default: {
5518
+ fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op));
5519
+ GGML_ABORT("fatal error");
5520
+ } break;
5521
  }
5522
 
5523
+ GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0]));
5524
+ GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1]));
5525
+ GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 
 
5526
  }
5527
 
5528
  static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
5529
  // check if already visited
5530
  if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
5531
  return;
 
5586
  ggml_build_forward_impl(cgraph, tensor, true);
5587
  }
5588
 
5589
+ void ggml_build_backward_expand(
5590
+ struct ggml_context * ctx_static,
5591
+ struct ggml_context * ctx_compute,
5592
+ struct ggml_cgraph * cgraph,
5593
+ bool accumulate) {
5594
+ GGML_ASSERT(cgraph->n_nodes > 0);
5595
+ GGML_ASSERT(cgraph->grads);
5596
+ GGML_ASSERT(cgraph->grad_accs);
5597
+
5598
+ const int n_nodes_f = cgraph->n_nodes;
5599
 
5600
+ const size_t hash_size = ggml_hash_size(2*cgraph->size);
5601
+ memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
5602
+ memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
5603
+ bool * grads_needed = calloc(hash_size, sizeof(bool));
5604
+
5605
+ {
5606
+ bool any_params = false;
5607
+ bool any_loss = false;
5608
+ for (int i = 0; i < n_nodes_f; ++i) {
5609
+ struct ggml_tensor * node = cgraph->nodes[i];
5610
+ any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM);
5611
+ any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS);
5612
+ }
5613
+ GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?");
5614
+ GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?");
5615
+ }
5616
+
5617
+ for (int i = 0; i < n_nodes_f; ++i) {
5618
+ struct ggml_tensor * node = cgraph->nodes[i];
5619
 
5620
  if (node->type == GGML_TYPE_I32) {
5621
  continue;
5622
  }
5623
 
5624
+ bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
5625
  bool ignore_src[GGML_MAX_SRC] = {false};
5626
  switch (node->op) {
5627
  // gradients in node->src[0] for one reason or another have no effect on output gradients
 
5649
  break;
5650
  }
5651
  for (int j = 0; j < GGML_MAX_SRC; ++j) {
5652
+ if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) {
5653
  continue;
5654
  }
5655
  GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16);
5656
+ node_needs_grad = true;
5657
  break;
5658
  }
5659
+ if (!node_needs_grad) {
5660
  continue;
5661
  }
5662
 
 
5664
  GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
5665
  node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
5666
 
5667
+ const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
5668
+ if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
5669
+ cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node);
5670
+ cgraph->grad_accs[igrad] = cgraph->grads[igrad];
5671
  }
5672
+ grads_needed[igrad] = true;
5673
  }
5674
 
5675
+ for (int i = n_nodes_f - 1; i >= 0; --i) {
 
 
5676
  // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
5677
  // use allocator to automatically make inplace operations
5678
+ ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
 
 
5679
  }
5680
 
5681
+ free(grads_needed);
5682
  }
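Putting the new signature in context, a minimal end-to-end sketch of driving this function. The split of tensors between the two contexts, the sizes and the names are illustrative assumptions; ggml_set_param and ggml_set_loss are the pre-existing markers the assertions above refer to, and the forward and backward passes now share a single cgraph:

    // assumed setup: ctx_static keeps parameters and gradient accumulators, ctx_compute keeps graph nodes
    const int64_t n_in = 8, n_out = 4;

    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);

    struct ggml_tensor * w = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, n_in, n_out);
    ggml_set_param(ctx_static, w);   // trainable parameter, a gradient will be tracked for it

    struct ggml_tensor * x    = ggml_new_tensor_1d(ctx_compute, GGML_TYPE_F32, n_in);
    struct ggml_tensor * loss = ggml_sum(ctx_compute, ggml_mul_mat(ctx_compute, w, x));
    ggml_set_loss(loss);             // the scalar whose gradient gets seeded with 1.0f on reset

    ggml_build_forward_expand(graph, loss);

    // one call appends the whole backward pass to the same graph; gradient accumulators
    // are allocated in ctx_static, backward compute nodes in ctx_compute
    ggml_build_backward_expand(ctx_static, ctx_compute, graph, /*accumulate =*/ false);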
5683
 
5684
  static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
 
5696
  incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
5697
  incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
5698
  if (grads) {
5699
+ incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
5700
+ incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs
5701
  }
5702
  incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
5703
 
 
5723
 
5724
  void * p = cgraph + 1;
5725
 
5726
+ struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5727
+ struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5728
+ struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
5729
+ struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
5730
+ struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
5731
+
5732
  ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
5733
 
5734
  // check that we allocated the correct amount of memory
 
5740
  /*.n_leafs =*/ 0,
5741
  /*.nodes =*/ nodes_ptr,
5742
  /*.grads =*/ grads_ptr,
5743
+ /*.grad_accs =*/ grad_accs_ptr,
5744
  /*.leafs =*/ leafs_ptr,
5745
  /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
5746
  /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
5747
  };
5748
 
5749
  ggml_hash_set_reset(&cgraph->visited_hash_set);
5750
+ if (grads) {
5751
+ memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
5752
+ memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
5753
+ }
5754
 
5755
  return cgraph;
5756
  }
 
5766
  /*.n_leafs =*/ 0,
5767
  /*.nodes =*/ cgraph0->nodes + i0,
5768
  /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
5769
+ /*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
5770
  /*.leafs =*/ NULL,
5771
  /*.hash_table =*/ { 0, NULL, NULL },
5772
  /*.order =*/ cgraph0->order,
 
5792
  dst->nodes[i] = src->nodes[i];
5793
  }
5794
 
5795
  for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
5796
  // copy all hashset keys (tensors) that are in use
5797
  if (ggml_bitset_get(src->visited_hash_set.used, i)) {
5798
  ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
5799
  }
5800
  }
5801
+
5802
+ if (src->grads) {
5803
+ GGML_ASSERT(dst->grads != NULL);
5804
+ GGML_ASSERT(dst->grad_accs != NULL);
5805
+ for (int i = 0; i < src->n_nodes; ++i) {
5806
+ const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
5807
+ const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
5808
+ dst->grads[igrad_dst] = src->grads[igrad_src];
5809
+ dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
5810
+ }
5811
+ }
5812
  }
5813
 
5814
  struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
 
5834
  GGML_ASSERT(cgraph->grads != NULL);
5835
 
5836
  for (int i = 0; i < cgraph->n_nodes; i++) {
5837
+ struct ggml_tensor * node = cgraph->nodes[i];
5838
+ struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node);
5839
+
5840
+ if (node->op == GGML_OP_OPT_STEP_ADAMW) {
5841
+ // clear momenta
5842
+ if (node->src[2]->data) {
5843
+ ggml_set_zero(node->src[2]);
5844
+ }
5845
+ if (node->src[3]->data) {
5846
+ ggml_set_zero(node->src[3]);
5847
+ }
5848
+ }
5849
 
5850
  // initial gradients of loss should be 1, 0 otherwise
5851
+ if (grad_acc) {
5852
  if (node->flags & GGML_TENSOR_FLAG_LOSS) {
5853
+ GGML_ASSERT(grad_acc->type == GGML_TYPE_F32);
5854
+ GGML_ASSERT(ggml_is_scalar(grad_acc));
 
5855
 
5856
  const float onef = 1.0f;
5857
+ if (grad_acc->buffer) {
5858
+ ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float));
5859
+ } else {
5860
+ GGML_ASSERT(grad_acc->data);
5861
+ *((float *) grad_acc->data) = onef;
5862
+ }
5863
  } else {
5864
+ ggml_set_zero(grad_acc);
5865
  }
5866
  }
5867
  }
5868
  }
5869
 
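Restating the reset semantics above from the caller's side, assuming graph has been built and allocated as sketched earlier and loss is the tensor marked with ggml_set_loss (both names are assumptions):

    ggml_graph_reset(graph);
    // after the call:
    //   - the gradient accumulator of the loss tensor holds 1.0f (written through its backend
    //     buffer if present, otherwise through its raw data pointer)
    //   - every other gradient accumulator is zeroed
    //   - the momentum tensors (src[2], src[3]) of any GGML_OP_OPT_STEP_ADAMW node are zeroed,
    //     provided their data is already allocated
    struct ggml_tensor * loss_acc = ggml_graph_get_grad_acc(graph, loss); // accessor added below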
 
5901
  cgraph->n_nodes++;
5902
  }
5903
 
5904
+ struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) {
5905
  for (int i = 0; i < cgraph->n_leafs; i++) {
5906
  struct ggml_tensor * leaf = cgraph->leafs[i];
5907
 
 
5921
  return NULL;
5922
  }
5923
 
5924
+ struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
5925
+ const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
5926
+ return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL;
5927
+ }
5928
+
5929
+ struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
5930
+ const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
5931
+ return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL;
5932
+ }
5933
+
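A short usage sketch for the two new accessors, assuming the graph and the parameter tensor w from the earlier sketches and backend-allocated tensor data:

    struct ggml_tensor * g = ggml_graph_get_grad(graph, w);        // NULL if no gradient was built for w
    if (g != NULL) {
        float g0;
        ggml_backend_tensor_get(g, &g0, 0, sizeof(g0));             // read back the first element of d(loss)/dw
    }
    struct ggml_tensor * g_acc = ggml_graph_get_grad_acc(graph, w); // non-NULL only for gradient accumulators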
5934
  void ggml_graph_print(const struct ggml_cgraph * cgraph) {
5935
  GGML_LOG_INFO("=== GRAPH ===\n");
5936
 
 
5941
  GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
5942
  i,
5943
  node->ne[0], node->ne[1], node->ne[2],
5944
+ ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" :
5945
+ ggml_graph_get_grad(cgraph, node) ? "g" : " ");
5946
  }
5947
 
5948
  GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
 
5977
  static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
5978
  for (int i = 0; i < cgraph->n_nodes; i++) {
5979
  struct ggml_tensor * parent = cgraph->nodes[i];
5980
+ struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent);
5981
 
5982
+ if (grad == node) {
5983
  return parent;
5984
  }
5985
  }
 
6019
 
6020
  for (int i = 0; i < gb->n_nodes; i++) {
6021
  struct ggml_tensor * node = gb->nodes[i];
6022
+ struct ggml_tensor * grad = ggml_graph_get_grad(gb, node);
6023
 
6024
  if (ggml_graph_get_parent(gb, node) != NULL) {
6025
  continue;
 
6027
 
6028
  if (node->flags & GGML_TENSOR_FLAG_PARAM) {
6029
  snprintf(color, sizeof(color), "yellow");
6030
+ } else if (grad) {
6031
  if (ggml_graph_find(gf, node)) {
6032
  snprintf(color, sizeof(color), "green");
6033
  } else {
 
6054
  fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | <x>%s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op));
6055
  }
6056
 
6057
+ if (grad) {
6058
+ fprintf(fp, " | <g>%s\"; ]\n", ggml_op_symbol(grad->op));
6059
  } else {
6060
  fprintf(fp, "\"; ]\n");
6061
  }