# AB498 — commit 8a48ded
import gradio as gr
from transformers import RobertaTokenizer, RobertaForMaskedLM
import torch
# Load the CodeBERT masked-language-model checkpoint and its matching tokenizer
# once at import time; predict() below reads these module-level objects.
model_name = "microsoft/codebert-base-mlm"
tokenizer, model = (
    RobertaTokenizer.from_pretrained(model_name),
    RobertaForMaskedLM.from_pretrained(model_name),
)
def predict(code, num_predictions=5):
    """
    Predict the token hidden behind the first ``<mask>`` in a code snippet.

    Use ``<mask>`` to indicate where to predict. If the snippet contains
    several ``<mask>`` markers, only the first one is scored and filled in;
    any later markers are left untouched in ``completed_code``.

    Args:
        code: Code snippet containing at least one ``<mask>`` token.
        num_predictions: Number of top predictions to return. UI sliders may
            deliver this as a float, so it is coerced to ``int``, clamped to
            at least 1 and capped at the model's vocabulary size.

    Returns:
        A JSON-serialisable dict. On success: ``original_code`` and
        ``predictions``, a list of dicts with ``rank`` (1-based), ``token``
        (decoded prediction), ``score`` (the raw logit, rounded to 4 places)
        and ``completed_code`` (the snippet with the first mask filled in).
        On failure: ``error`` (a message) and an empty ``predictions`` list.
    """
    no_mask_error = {
        "error": "No <mask> token found in the input. Please include <mask> where you want predictions.",
        "predictions": [],
    }
    # Guard against an empty/None textbox instead of surfacing an AttributeError.
    if not code:
        return no_mask_error

    try:
        # Swap the user-facing marker for the tokenizer's own mask token.
        code_input = code.replace("<mask>", tokenizer.mask_token)
        inputs = tokenizer(code_input, return_tensors="pt")

        # Sequence positions of every mask token in the single batch row.
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if len(mask_token_index) == 0:
            return no_mask_error

        with torch.no_grad():
            logits = model(**inputs).logits

        # Only the first mask is ranked: one logit per vocabulary id.
        first_mask_logits = logits[0, mask_token_index[0], :]

        # Gradio sliders pass floats; topk needs an int k in [1, vocab_size].
        k = min(max(1, int(num_predictions)), first_mask_logits.shape[-1])
        top_tokens = torch.topk(first_mask_logits, k)

        predictions = []
        for rank, (token_id, score) in enumerate(
            zip(top_tokens.indices.tolist(), top_tokens.values.tolist()), start=1
        ):
            predicted_token = tokenizer.decode([token_id])
            # Fill only the first mask, since that is the one this prediction is for.
            completed_code = code_input.replace(tokenizer.mask_token, predicted_token, 1)
            predictions.append({
                "rank": rank,
                "token": predicted_token,
                "score": round(float(score), 4),
                "completed_code": completed_code,
            })

        return {
            "original_code": code,
            "predictions": predictions,
        }
    except Exception as e:
        # UI boundary: report the failure in the JSON panel rather than raising.
        return {
            "error": str(e),
            "predictions": [],
        }
# Assemble the web UI: a code textbox and prediction-count slider on the left,
# a JSON panel showing predict()'s result on the right.
with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
    # Intro text; adjacent literals keep the exact Markdown independent of indentation.
    gr.Markdown(
        "\n"
        "# CodeBERT Masked Language Model\n"
        "This model predicts masked tokens in code. Use `<mask>` to indicate where you want predictions.\n"
        "### Examples:\n"
        "- `def <mask>(x, y): return x + y`\n"
        "- `import <mask>`\n"
        "- `for i in <mask>(10):`\n"
        "- `x = [1, 2, 3]; y = x.<mask>()`\n"
    )

    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                value="def <mask>(x, y):\n return x + y",
                lines=5,
                placeholder="Enter code with <mask> token...",
                label="Code with <mask>",
            )
            num_predictions_slider = gr.Slider(
                label="Number of predictions",
                step=1,
                value=5,
                maximum=10,
                minimum=1,
            )
            predict_btn = gr.Button("Predict", variant="primary")
        with gr.Column():
            output = gr.JSON(label="Predictions")

    # Clickable sample snippets, each requesting five predictions.
    gr.Examples(
        examples=[
            [snippet, 5]
            for snippet in (
                "def <mask>(x, y):\n return x + y",
                "import <mask>",
                "for i in <mask>(10):",
                "x = [1, 2, 3]\ny = x.<mask>()",
                "if x <mask> 0:",
                "class <mask>:",
            )
        ],
        inputs=[code_input, num_predictions_slider],
    )

    # Wire the button: (code, count) -> predict() -> JSON panel.
    predict_btn.click(fn=predict, inputs=[code_input, num_predictions_slider], outputs=output)

if __name__ == "__main__":
    # Start the Gradio server only when this file is run directly.
    demo.launch()