Spaces:
Sleeping
Sleeping
Commit
·
341f451
1
Parent(s):
5c0b540
mnist: fix segmentation fault (ggml/1227)
Browse files- ggml/include/ggml-opt.h +2 -0
- ggml/src/ggml-opt.cpp +5 -0
ggml/include/ggml-opt.h
CHANGED
|
@@ -128,6 +128,8 @@ extern "C" {
|
|
| 128 |
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
| 129 |
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
| 130 |
|
|
|
|
|
|
|
| 131 |
// get underlying tensors that store data
|
| 132 |
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
| 133 |
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
|
|
|
| 128 |
// set gradients to zero, initilize loss, and optionally reset the optimizer
|
| 129 |
GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
|
| 130 |
|
| 131 |
+
GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
|
| 132 |
+
|
| 133 |
// get underlying tensors that store data
|
| 134 |
// if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
|
| 135 |
GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
|
ggml/src/ggml-opt.cpp
CHANGED
|
@@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
|
|
| 576 |
}
|
| 577 |
}
|
| 578 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
| 580 |
return opt_ctx->inputs;
|
| 581 |
}
|
|
@@ -842,6 +846,7 @@ void ggml_opt_epoch(
|
|
| 842 |
int64_t idata_split,
|
| 843 |
ggml_opt_epoch_callback callback_train,
|
| 844 |
ggml_opt_epoch_callback callback_eval) {
|
|
|
|
| 845 |
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
| 846 |
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
| 847 |
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|
|
|
|
| 576 |
}
|
| 577 |
}
|
| 578 |
|
| 579 |
+
bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
|
| 580 |
+
return opt_ctx->static_graphs;
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
|
| 584 |
return opt_ctx->inputs;
|
| 585 |
}
|
|
|
|
| 846 |
int64_t idata_split,
|
| 847 |
ggml_opt_epoch_callback callback_train,
|
| 848 |
ggml_opt_epoch_callback callback_eval) {
|
| 849 |
+
GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
|
| 850 |
struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
|
| 851 |
struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
|
| 852 |
struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
|