GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval
Paper
•
2112.07577
•
Published
This is a doc2query model based on mT5 (also known as docT5query).
It can be used for:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
model_name = 'NghiemAbe/Law-Doc2Query'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
text = "1. Bố trí lực lượng tham gia tuần tra, kiểm soát trật tự, an toàn giao thông theo kế hoạch. 2. Thống kê, báo cáo các vụ, việc vi phạm pháp luật, tai nạn giao thông đường bộ; kết quả tuần tra, kiểm soát và xử lý vi phạm hành chính về trật tự, an toàn giao thông đường bộ theo sự phân công trong kế hoạch. 3. Trường hợp không có lực lượng Cảnh sát giao thông đi cùng thì lực lượng Cảnh sát khác và Công an xã thực hiện việc tuần tra, kiểm soát theo kế hoạch đã được cấp có thẩm quyền phê duyệt. 4. Lực lượng Công an xã chỉ được tuần tra, kiểm soát trên các tuyến đường liên xã, liên thôn thuộc địa bàn quản lý và xử lý các hành vi vi phạm trật tự, an toàn giao thông sau: điều khiển xe mô tô, xe gắn máy không đội mũ bảo hiểm, chở quá số người quy định, chở hàng hóa cồng kềnh; đỗ xe ở lòng đường trái quy định; điều khiển phương tiện phóng nhanh, lạng lách, đánh võng, tháo ống xả, không có gương chiếu hậu hoặc chưa đủ tuổi điều khiển phương tiện theo quy định của pháp luật và các hành vi vi phạm hành lang an toàn giao thông đường bộ như họp chợ dưới lòng đường, lấn chiếm hành lang an toàn giao thông. Nghiêm cấm việc Công an xã dừng xe, kiểm soát trên các tuyến quốc lộ, tỉnh lộ."
def create_queries(para):
input_ids = tokenizer.encode(para, return_tensors='pt')
with torch.no_grad():
# Here we use top_k / top_k random sampling. It generates more diverse queries, but of lower quality
sampling_outputs = model.generate(
input_ids=input_ids,
max_length=64,
do_sample=True,
top_p=0.95,
top_k=10,
num_return_sequences=5
)
# Here we use Beam-search. It generates better quality queries, but with less diversity
beam_outputs = model.generate(
input_ids=input_ids,
max_length=64,
num_beams=5,
no_repeat_ngram_size=2,
num_return_sequences=5,
early_stopping=True
)
print("Paragraph:")
print(para)
print("\nBeam Outputs:")
for i in range(len(beam_outputs)):
query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
print("\nSampling Outputs:")
for i in range(len(sampling_outputs)):
query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
create_queries(text)
Beam Outputs:
1: Trách nhiệm của Công an xã trong việc tuần tra, kiểm soát giao thông đường bộ được quy định như thế nào?
2: Trách nhiệm của Công an xã trong việc tuần tra, kiểm soát trật tự, an toàn giao thông là gì?
3: Công an xã có được tuần tra, kiểm soát hành lang an toàn giao thông không?
4: Công an xã có được tuần tra, kiểm soát trên các tuyến đường liên thôn không?
5: Lực lượng Công an xã có được tuần tra, kiểm soát trên các tuyến đường liên thôn không?
Sampling Outputs:
1: Tiêu chuẩn về hành vi vi phạm hành lang an toàn giao thbuffer được quy định như thế nào?
2: Trách nhiệm của Công an xã trong việc xử lý các hành vi vi phạm hành chính về đường bộ là gì?
3: Trách nhiệm của lực lượng Cảnh sát giao thông đối với tình trạng tai nạn giao thông (07/2016) được quy định như thế nào?
4: Lực lượng Công an xã có được tuần tra trong các tuyến đường lớn, liên thôn không?
5: Cảnh sát giao thông có Nordland dừng xe không?
Note: model.generate() is non-deterministic for top_k/top_n sampling. It produces different queries each time you run it.
This model fine-tuned doc2query/msmarco-vietnamese-mt5-base-v1 for 4k training steps (4 epochs on the 2k5 training pairs from Legal).