|
|
#include <torch/all.h> |
|
|
#include <torch/library.h> |
|
|
|
|
|
#include "registration.h" |
|
|
#include "torch_binding.h" |
|
|
|
|
|
/// Dispatcher-facing adapter around `lsh_cumulation`.
///
/// The integer arguments arrive as `int64_t` and are narrowed to `int` before
/// being forwarded. Values that do not survive the narrowing are rejected
/// instead of being silently wrapped.
///
/// @param query_mask          forwarded unchanged to `lsh_cumulation`
/// @param query_hash_code     forwarded unchanged to `lsh_cumulation`
/// @param key_mask            forwarded unchanged to `lsh_cumulation`
/// @param key_hash_code       forwarded unchanged to `lsh_cumulation`
/// @param value               forwarded unchanged to `lsh_cumulation`
/// @param hashtable_capacity  narrowed to `int`; must be representable as `int`
/// @param use_cuda            forwarded unchanged to `lsh_cumulation`
/// @param version             narrowed to `int`; must be representable as `int`
/// @return whatever `lsh_cumulation` returns
/// @throws c10::Error (via TORCH_CHECK) if an integer argument does not fit
///         in `int`
at::Tensor lsh_cumulation_wrapper(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor value,
    int64_t hashtable_capacity,
    bool use_cuda,
    int64_t version
) {
  // Round-trip check: if converting back to int64_t does not reproduce the
  // input, the value was out of range for `int`.
  const auto to_int = [](int64_t wide, const char* arg_name) -> int {
    const int narrowed = static_cast<int>(wide);
    TORCH_CHECK(
        static_cast<int64_t>(narrowed) == wide,
        "lsh_cumulation: ", arg_name, " = ", wide, " does not fit in an int");
    return narrowed;
  };

  const int capacity = to_int(hashtable_capacity, "hashtable_capacity");
  const int kernel_version = to_int(version, "version");

  return lsh_cumulation(
      query_mask,
      query_hash_code,
      key_mask,
      key_hash_code,
      value,
      capacity,
      use_cuda,
      kernel_version);
}
|
|
|
|
|
/// Dispatcher-facing adapter around `fast_hash`.
///
/// The integer arguments arrive as `int64_t` and are narrowed to `int` before
/// being forwarded. Values that do not survive the narrowing are rejected
/// instead of being silently wrapped.
///
/// @param query_mask     forwarded unchanged to `fast_hash`
/// @param query_vector   forwarded unchanged to `fast_hash`
/// @param key_mask       forwarded unchanged to `fast_hash`
/// @param key_vector     forwarded unchanged to `fast_hash`
/// @param num_hash_f     narrowed to `int`; must be representable as `int`
/// @param hash_code_len  narrowed to `int`; must be representable as `int`
/// @param use_cuda       forwarded unchanged to `fast_hash`
/// @param version        narrowed to `int`; must be representable as `int`
/// @return the tensor list produced by `fast_hash`
/// @throws c10::Error (via TORCH_CHECK) if an integer argument does not fit
///         in `int`
std::vector<at::Tensor> fast_hash_wrapper(
    at::Tensor query_mask,
    at::Tensor query_vector,
    at::Tensor key_mask,
    at::Tensor key_vector,
    int64_t num_hash_f,
    int64_t hash_code_len,
    bool use_cuda,
    int64_t version
) {
  // Round-trip check: if converting back to int64_t does not reproduce the
  // input, the value was out of range for `int`.
  const auto to_int = [](int64_t wide, const char* arg_name) -> int {
    const int narrowed = static_cast<int>(wide);
    TORCH_CHECK(
        static_cast<int64_t>(narrowed) == wide,
        "fast_hash: ", arg_name, " = ", wide, " does not fit in an int");
    return narrowed;
  };

  const int hash_function_count = to_int(num_hash_f, "num_hash_f");
  const int code_length = to_int(hash_code_len, "hash_code_len");
  const int kernel_version = to_int(version, "version");

  return fast_hash(
      query_mask,
      query_vector,
      key_mask,
      key_vector,
      hash_function_count,
      code_length,
      use_cuda,
      kernel_version);
}
|
|
|
|
|
/// Dispatcher-facing adapter around `lsh_weighted_cumulation`.
///
/// The integer arguments arrive as `int64_t` and are narrowed to `int` before
/// being forwarded. Values that do not survive the narrowing are rejected
/// instead of being silently wrapped.
///
/// @param query_mask          forwarded unchanged
/// @param query_hash_code     forwarded unchanged
/// @param query_weight        forwarded unchanged
/// @param key_mask            forwarded unchanged
/// @param key_hash_code       forwarded unchanged
/// @param key_weight          forwarded unchanged
/// @param value               forwarded unchanged
/// @param hashtable_capacity  narrowed to `int`; must be representable as `int`
/// @param use_cuda            forwarded unchanged
/// @param version             narrowed to `int`; must be representable as `int`
/// @return whatever `lsh_weighted_cumulation` returns
/// @throws c10::Error (via TORCH_CHECK) if an integer argument does not fit
///         in `int`
at::Tensor lsh_weighted_cumulation_wrapper(
    at::Tensor query_mask,
    at::Tensor query_hash_code,
    at::Tensor query_weight,
    at::Tensor key_mask,
    at::Tensor key_hash_code,
    at::Tensor key_weight,
    at::Tensor value,
    int64_t hashtable_capacity,
    bool use_cuda,
    int64_t version
) {
  // Round-trip check: if converting back to int64_t does not reproduce the
  // input, the value was out of range for `int`.
  const auto to_int = [](int64_t wide, const char* arg_name) -> int {
    const int narrowed = static_cast<int>(wide);
    TORCH_CHECK(
        static_cast<int64_t>(narrowed) == wide,
        "lsh_weighted_cumulation: ", arg_name, " = ", wide,
        " does not fit in an int");
    return narrowed;
  };

  const int capacity = to_int(hashtable_capacity, "hashtable_capacity");
  const int kernel_version = to_int(version, "version");

  return lsh_weighted_cumulation(
      query_mask,
      query_hash_code,
      query_weight,
      key_mask,
      key_hash_code,
      key_weight,
      value,
      capacity,
      use_cuda,
      kernel_version);
}
|
|
// Operator registration: each op gets a schema definition followed by an
// implementation bound to the CUDA dispatch key.
// NOTE(review): only torch::kCUDA implementations are registered here even
// though every schema carries a `use_cuda` flag — confirm whether a CPU
// registration is expected elsewhere.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schemas are kept as named constants so the def/impl pairs below stay short.
  static constexpr char kLshCumulationSchema[] =
      "lsh_cumulation(Tensor query_mask, Tensor query_hash_code, Tensor key_mask, Tensor key_hash_code, Tensor value, int hashtable_capacity, bool use_cuda, int version) -> Tensor";
  static constexpr char kFastHashSchema[] =
      "fast_hash(Tensor query_mask, Tensor query_vector, Tensor key_mask, Tensor key_vector, int num_hash_f, int hash_code_len, bool use_cuda, int version) -> Tensor[]";
  static constexpr char kLshWeightedCumulationSchema[] =
      "lsh_weighted_cumulation(Tensor query_mask, Tensor query_hash_code, Tensor query_weight, Tensor key_mask, Tensor key_hash_code, Tensor key_weight, Tensor value, int hashtable_capacity, bool use_cuda, int version) -> Tensor";

  // lsh_cumulation
  ops.def(kLshCumulationSchema);
  ops.impl("lsh_cumulation", torch::kCUDA, &lsh_cumulation_wrapper);

  // fast_hash
  ops.def(kFastHashSchema);
  ops.impl("fast_hash", torch::kCUDA, &fast_hash_wrapper);

  // lsh_weighted_cumulation
  ops.def(kLshWeightedCumulationSchema);
  ops.impl(
      "lsh_weighted_cumulation",
      torch::kCUDA,
      &lsh_weighted_cumulation_wrapper);
}

// Macro presumably supplied by "registration.h" (not visible in this file);
// it looks like it emits the extension module's entry point — confirm there.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)