#include <torch/all.h>
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

// Thin wrappers: the TORCH_LIBRARY schema passes `int` arguments as int64_t,
// so each wrapper narrows them to the `int` parameters the kernels expect.
at::Tensor lsh_cumulation_wrapper(
    at::Tensor query_mask,         // [batch_size, num_query]
    at::Tensor query_hash_code,    // [batch_size, num_query, num_hash_f]
    at::Tensor key_mask,           // [batch_size, num_key]
    at::Tensor key_hash_code,      // [batch_size, num_key, num_hash_f]
    at::Tensor value,              // [batch_size, num_key, value_dim]
    int64_t hashtable_capacity,
    bool use_cuda,
    int64_t version
) {
  return lsh_cumulation(
    query_mask,
    query_hash_code,
    key_mask,
    key_hash_code,
    value,
    static_cast<int>(hashtable_capacity),
    use_cuda,
    static_cast<int>(version)
  );
}

std::vector<at::Tensor> fast_hash_wrapper(
    at::Tensor query_mask,         // [batch_size, num_query]
    at::Tensor query_vector,       // [batch_size, num_query, vector_dim]
    at::Tensor key_mask,           // [batch_size, num_key]
    at::Tensor key_vector,         // [batch_size, num_key, vector_dim]
    int64_t num_hash_f,
    int64_t hash_code_len,
    bool use_cuda,
    int64_t version
) {
  return fast_hash(
    query_mask,
    query_vector,
    key_mask,
    key_vector,
    static_cast<int>(num_hash_f),
    static_cast<int>(hash_code_len),
    use_cuda,
    static_cast<int>(version)
  );
}

at::Tensor lsh_weighted_cumulation_wrapper(
    at::Tensor query_mask,         // [batch_size, num_query]
    at::Tensor query_hash_code,    // [batch_size, num_query, num_hash_f]
    at::Tensor query_weight,       // [batch_size, num_query, weight_dim]
    at::Tensor key_mask,           // [batch_size, num_key]
    at::Tensor key_hash_code,      // [batch_size, num_key, num_hash_f]
    at::Tensor key_weight,         // [batch_size, num_key, weight_dim]
    at::Tensor value,              // [batch_size, num_key, value_dim]
    int64_t hashtable_capacity,
    bool use_cuda,
    int64_t version
) {
  return lsh_weighted_cumulation(
    query_mask,
    query_hash_code,
    query_weight,
    key_mask,
    key_hash_code,
    key_weight,
    value,
    static_cast<int>(hashtable_capacity),
    use_cuda,
    static_cast<int>(version)
  );
}

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("lsh_cumulation(Tensor query_mask, Tensor query_hash_code, Tensor key_mask, Tensor key_hash_code, Tensor value, int hashtable_capacity, bool use_cuda, int version) -> Tensor");
  ops.impl("lsh_cumulation", torch::kCUDA, &lsh_cumulation_wrapper);

  ops.def("fast_hash(Tensor query_mask, Tensor query_vector, Tensor key_mask, Tensor key_vector, int num_hash_f, int hash_code_len, bool use_cuda, int version) -> Tensor[]");
  ops.impl("fast_hash", torch::kCUDA, &fast_hash_wrapper);

  ops.def("lsh_weighted_cumulation(Tensor query_mask, Tensor query_hash_code, Tensor query_weight, Tensor key_mask, Tensor key_hash_code, Tensor key_weight, Tensor value, int hashtable_capacity, bool use_cuda, int version) -> Tensor");
  ops.impl("lsh_weighted_cumulation", torch::kCUDA, &lsh_weighted_cumulation_wrapper);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
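
// Example Python-side usage (a sketch, not part of this file's API: the ops
// namespace equals TORCH_EXTENSION_NAME chosen at build time, and the tensor
// names and hyperparameter values below are illustrative assumptions):
//
//   import torch
//   ops = torch.ops.my_ext  # replace `my_ext` with the built extension name
//
//   # Hash queries and keys, then accumulate values whose keys collide with
//   # each query under the LSH codes.
//   q_codes, k_codes = ops.fast_hash(
//       query_mask, query_vector, key_mask, key_vector,
//       32, 9, True, 1)  # num_hash_f, hash_code_len, use_cuda, version
//   out = ops.lsh_cumulation(
//       query_mask, q_codes, key_mask, k_codes, value,
//       2 ** 9, True, 1)  # hashtable_capacity = 2 ** hash_code_len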