import gradio as gr
from transformers import RobertaTokenizer, RobertaForMaskedLM
import torch

# Load CodeBERT model and tokenizer
model_name = "microsoft/codebert-base-mlm"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForMaskedLM.from_pretrained(model_name)

# Placeholder the user types where a prediction is wanted. It is substituted
# with the tokenizer's own mask token before tokenization.
MASK_PLACEHOLDER = "<mask>"


def predict(code, num_predictions=5):
    """
    Predict the masked token in code. Use <mask> to indicate where to predict.

    Only the first <mask> in the input is predicted; any further <mask>
    occurrences are left untouched in ``completed_code``.

    Args:
        code: Code snippet containing a <mask> token.
        num_predictions: Number of top predictions to return. Cast to int
            because UI sliders may deliver the value as a float.

    Returns:
        A JSON-serializable dict. On success it holds ``original_code`` and
        ``predictions`` (a list of dicts with ``rank``, ``token``, ``score``
        and ``completed_code``; ``score`` is the raw logit rounded to four
        decimals). On failure it holds ``error`` and an empty ``predictions``
        list.
    """
    try:
        # Replace <mask> with the tokenizer's mask token. The search string
        # must be non-empty: str.replace("", x) would insert x between every
        # character of the input.
        code_input = code.replace(MASK_PLACEHOLDER, tokenizer.mask_token)

        # Tokenize input
        inputs = tokenizer(code_input, return_tensors="pt")

        # Find the position(s) of the mask token along the sequence axis
        mask_token_index = torch.where(
            inputs["input_ids"] == tokenizer.mask_token_id
        )[1]

        if mask_token_index.numel() == 0:
            return {
                "error": (
                    "No <mask> token found in the input. "
                    "Please include <mask> where you want predictions."
                ),
                "predictions": [],
            }

        # Get predictions
        with torch.no_grad():
            logits = model(**inputs).logits

        # Get top-k predictions for the first mask token only
        first_mask_logits = logits[0, mask_token_index[0], :]
        top_k = int(num_predictions)  # torch.topk requires an integer k
        top_tokens = torch.topk(first_mask_logits, top_k)

        predictions = []
        ranked = zip(top_tokens.indices.tolist(), top_tokens.values.tolist())
        for rank, (token_id, score) in enumerate(ranked, 1):
            predicted_token = tokenizer.decode([token_id])
            # Fill only the first mask: the prediction was computed for that
            # position, so it must not be copied into any other mask.
            completed_code = code_input.replace(
                tokenizer.mask_token, predicted_token, 1
            )
            predictions.append({
                "rank": rank,
                "token": predicted_token,
                "score": round(float(score), 4),
                "completed_code": completed_code,
            })

        return {
            "original_code": code,
            "predictions": predictions,
        }

    except Exception as e:
        # UI boundary: report the failure in the JSON output instead of
        # letting the exception propagate to the Gradio event handler.
        return {
            "error": str(e),
            "predictions": [],
        }


# Create Gradio interface
with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
    gr.Markdown(
        """
        # CodeBERT Masked Language Model

        This model predicts masked tokens in code. Use `<mask>` to indicate where you want predictions.

        ### Examples:
        - `def <mask>(x, y): return x + y`
        - `import <mask>`
        - `for i in <mask>(10):`
        - `x = [1, 2, 3]; y = x.<mask>()`
        """
    )

    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Code with <mask>",
                placeholder="Enter code with <mask> token...",
                lines=5,
                value="def <mask>(x, y):\n    return x + y",
            )
            num_predictions_slider = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of predictions",
            )
            predict_btn = gr.Button("Predict", variant="primary")

        with gr.Column():
            output = gr.JSON(
                label="Predictions"
            )

    # Examples
    gr.Examples(
        examples=[
            ["def <mask>(x, y):\n    return x + y", 5],
            ["import <mask>", 5],
            ["for i in <mask>(10):", 5],
            ["x = [1, 2, 3]\ny = x.<mask>()", 5],
            ["if x <mask> 0:", 5],
            ["class <mask>:", 5],
        ],
        inputs=[code_input, num_predictions_slider],
    )

    predict_btn.click(
        fn=predict,
        inputs=[code_input, num_predictions_slider],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()