slaren commited on
Commit
050174c
·
1 Parent(s): 64976cd

ggml-backend : fix async copy from CPU (llama/8897)

Browse files

* ggml-backend : fix async copy from CPU

* cuda : more reliable async copy, fix stream used when the devices are the same

Files changed (2) hide show
  1. ggml/src/ggml-backend.c +15 -10
  2. ggml/src/ggml-cuda.cu +15 -13
ggml/src/ggml-backend.c CHANGED
@@ -351,15 +351,10 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b
351
  }
352
 
353
  // an async copy would normally happen after all the queued operations on both backends are completed
354
- // sync src, set_async dst
355
- if (ggml_backend_buffer_is_host(src->buffer)) {
356
- ggml_backend_synchronize(backend_src);
357
- ggml_backend_tensor_set_async(backend_dst, dst, src->data, 0, ggml_nbytes(src));
358
- } else {
359
- ggml_backend_synchronize(backend_src);
360
- ggml_backend_tensor_copy(src, dst);
361
- ggml_backend_synchronize(backend_dst);
362
- }
363
  }
364
 
365
  // events
@@ -1782,7 +1777,17 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
1782
  } else {
1783
  ggml_backend_synchronize(split_backend);
1784
  }
1785
- ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
 
 
 
 
 
 
 
 
 
 
1786
  }
1787
  }
1788
 
 
351
  }
352
 
353
  // an async copy would normally happen after all the queued operations on both backends are completed
354
+ // to simulate the same behavior, we need to synchronize both backends first, and do a blocking copy
355
+ ggml_backend_synchronize(backend_src);
356
+ ggml_backend_synchronize(backend_dst);
357
+ ggml_backend_tensor_copy(src, dst);
 
 
 
 
 
358
  }
359
 
360
  // events
 
1777
  } else {
1778
  ggml_backend_synchronize(split_backend);
1779
  }
1780
+ // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events
1781
+ // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface
1782
+ if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) {
1783
+ ggml_backend_synchronize(input_backend);
1784
+ if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
1785
+ ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
1786
+ } else {
1787
+ ggml_backend_synchronize(split_backend);
1788
+ }
1789
+ ggml_backend_tensor_copy(input, input_cpy);
1790
+ }
1791
  }
1792
  }
1793
 
ggml/src/ggml-cuda.cu CHANGED
@@ -2358,33 +2358,35 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend,
2358
  }
2359
 
2360
  GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
2361
- GGML_ASSERT(ggml_backend_is_cuda(backend_src) || ggml_backend_is_cuda(backend_dst));
2362
-
2363
  ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
2364
  ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
2365
 
2366
- if (!ggml_backend_buffer_is_cuda(src->buffer)) {
2367
  return false;
2368
  }
2369
 
2370
- if (!ggml_backend_buffer_is_cuda(dst->buffer)) {
2371
  return false;
2372
  }
2373
 
2374
- // device -> device
2375
  ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
2376
  ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
2377
 
2378
- if (backend_src != backend_dst) {
2379
- ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
2380
- ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
2381
 
2382
- GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device);
2383
- GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device);
 
 
 
 
2384
 
 
2385
  // copy on src stream
2386
  if (cuda_ctx_src->device == cuda_ctx_dst->device) {
2387
- CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
2388
  } else {
2389
  #ifdef GGML_CUDA_NO_PEER_COPY
2390
  return false;
@@ -2393,7 +2395,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
2393
  #endif
2394
  }
2395
 
2396
- // record event on src stream
2397
  if (!cuda_ctx_src->copy_event) {
2398
  ggml_cuda_set_device(cuda_ctx_src->device);
2399
  CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
@@ -2405,7 +2407,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_
2405
  CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
2406
  } else {
2407
  // src and dst are on the same backend
2408
- CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream()));
2409
  }
2410
  return true;
2411
  }
 
2358
  }
2359
 
2360
  GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
 
 
2361
  ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
2362
  ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
2363
 
2364
+ if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
2365
  return false;
2366
  }
2367
 
2368
+ if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
2369
  return false;
2370
  }
2371
 
2372
+ // device -> device copy
2373
  ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
2374
  ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
2375
 
2376
+ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
2377
+ ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
 
2378
 
2379
+ if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
2380
+ #ifndef NDEBUG
2381
+ GGML_CUDA_LOG_WARN("%s: backend and buffer devices do not match\n", __func__);
2382
+ #endif
2383
+ return false;
2384
+ }
2385
 
2386
+ if (backend_src != backend_dst) {
2387
  // copy on src stream
2388
  if (cuda_ctx_src->device == cuda_ctx_dst->device) {
2389
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
2390
  } else {
2391
  #ifdef GGML_CUDA_NO_PEER_COPY
2392
  return false;
 
2395
  #endif
2396
  }
2397
 
2398
+ // record event on src stream after the copy
2399
  if (!cuda_ctx_src->copy_event) {
2400
  ggml_cuda_set_device(cuda_ctx_src->device);
2401
  CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
 
2407
  CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
2408
  } else {
2409
  // src and dst are on the same backend
2410
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
2411
  }
2412
  return true;
2413
  }