JohannesGaessler commited on
Commit
341f451
·
1 Parent(s): 5c0b540

mnist: fix segmentation fault (ggml/1227)

Browse files
Files changed (2) hide show
  1. ggml/include/ggml-opt.h +2 -0
  2. ggml/src/ggml-opt.cpp +5 -0
ggml/include/ggml-opt.h CHANGED
@@ -128,6 +128,8 @@ extern "C" {
128
  // set gradients to zero, initilize loss, and optionally reset the optimizer
129
  GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
130
 
 
 
131
  // get underlying tensors that store data
132
  // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
133
  GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
 
128
  // set gradients to zero, initilize loss, and optionally reset the optimizer
129
  GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer);
130
 
131
+ GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically
132
+
133
  // get underlying tensors that store data
134
  // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc
135
  GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor
ggml/src/ggml-opt.cpp CHANGED
@@ -576,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) {
576
  }
577
  }
578
 
 
 
 
 
579
  struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
580
  return opt_ctx->inputs;
581
  }
@@ -842,6 +846,7 @@ void ggml_opt_epoch(
842
  int64_t idata_split,
843
  ggml_opt_epoch_callback callback_train,
844
  ggml_opt_epoch_callback callback_eval) {
 
845
  struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
846
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
847
  struct ggml_tensor * data = ggml_opt_dataset_data(dataset);
 
576
  }
577
  }
578
 
579
+ bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) {
580
+ return opt_ctx->static_graphs;
581
+ }
582
+
583
  struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) {
584
  return opt_ctx->inputs;
585
  }
 
846
  int64_t idata_split,
847
  ggml_opt_epoch_callback callback_train,
848
  ggml_opt_epoch_callback callback_eval) {
849
+ GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs");
850
  struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx);
851
  struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
852
  struct ggml_tensor * data = ggml_opt_dataset_data(dataset);