# Original notebook header preserved as a comment ("Spaces:" / "Runtime error" x2);
# these were platform status lines, not part of the program.
import os
import random
import re
import time
import zipfile
import zlib

import javalang
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data, Dataset, DataLoader
from tqdm import tqdm
| # ---- Utility functions ---- | |
def unzip_dataset(zip_path, extract_to):
    """Extract every member of the ZIP archive at *zip_path* into *extract_to*."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(path=extract_to)
def normalize_java_code(code):
    """Strip comments from Java source and collapse all whitespace.

    Returns the code as a single space-separated line. NOTE(review):
    comment markers inside string literals (e.g. "http://...") are also
    stripped by this regex-based approach — acceptable for clone-detection
    preprocessing, but not a faithful Java lexer.
    """
    # Remove multi-line comments first so a '//' inside '/* ... */' cannot
    # mask the closing '*/'.
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
    # Remove single-line comments. '[^\n]*' (instead of the original
    # '.*?\n') also strips a trailing comment on a last line that has no
    # terminating newline.
    code = re.sub(r'//[^\n]*', '', code)
    # Collapse runs of whitespace (including newlines) to single spaces.
    code = re.sub(r'\s+', ' ', code)
    return code.strip()
def safe_parse_java(code):
    """Parse Java source into a javalang AST; return None on any failure."""
    try:
        token_stream = list(javalang.tokenizer.tokenize(code))
        return javalang.parser.Parser(token_stream).parse()
    except Exception:
        # Tokenizer or parser errors on malformed input -> treat as unparseable.
        return None
def ast_to_graph(ast):
    """Convert a javalang AST into a networkx DiGraph.

    Each AST node becomes a graph node whose 'label' attribute is the node's
    class name; edges point from parent to child. Node ids are assigned in
    preorder. Uses an explicit stack instead of recursion so very deep ASTs
    cannot raise RecursionError (the original recursive DFS could).
    """
    graph = nx.DiGraph()
    # Stack of (ast_node, parent_graph_id); the root has no parent.
    stack = [(ast, None)]
    while stack:
        node, parent_id = stack.pop()
        node_id = len(graph)
        graph.add_node(node_id, label=type(node).__name__)
        if parent_id is not None:
            graph.add_edge(parent_id, node_id)
        # javalang exposes children either directly or nested in lists/tuples;
        # only actual AST nodes are traversed.
        children = []
        for child in getattr(node, 'children', []):
            if isinstance(child, (list, tuple)):
                children.extend(c for c in child if isinstance(c, javalang.ast.Node))
            elif isinstance(child, javalang.ast.Node):
                children.append(child)
        # Push reversed so children pop (and get ids) in original order,
        # matching the recursive preorder numbering exactly.
        for child in reversed(children):
            stack.append((child, node_id))
    return graph
def tokenize_java_code(code):
    """Return the list of token values for *code*, or [] if tokenization fails."""
    try:
        return [token.value for token in javalang.tokenizer.tokenize(code)]
    # Narrowed from a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit.
    except Exception:
        return []
| # ---- Data Preprocessing ---- | |
class CloneDataset(Dataset):
    """Dataset of Java files labelled clone (1) / not-clone (0).

    Walks the clone-type directories under *root_dir*, normalizes and parses
    each .java file, and keeps (AST graph, token list, label) triples.
    Files that are too long, unparseable, or untokenizable are counted in
    ``skipped_files`` and dropped.
    """

    def __init__(self, root_dir, transform=None):
        super().__init__()
        self.data_list = []
        self.labels = []
        self.skipped_files = 0
        self.max_tokens = 5000  # skip very large files to bound parse time/memory
        clone_dirs = {
            "Clone_Type1": 1,
            "Clone_Type2": 1,
            "Clone_Type3 - ST": 1,
            "Clone_Type3 - VST": 1,
            "Clone_Type3 - MT": 0,  # assumes MT = Not Clone -- TODO confirm against dataset docs
        }
        for clone_type, label in clone_dirs.items():
            clone_path = os.path.join(root_dir, 'Subject_CloneTypes_Directories', clone_type)
            for root, _, files in os.walk(clone_path):
                for file in files:
                    if not file.endswith(".java"):
                        continue
                    file_path = os.path.join(root, file)
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        code = f.read()
                    code = normalize_java_code(code)
                    if len(code.split()) > self.max_tokens:
                        self.skipped_files += 1
                        continue
                    ast = safe_parse_java(code)
                    if ast is None:
                        self.skipped_files += 1
                        continue
                    graph = ast_to_graph(ast)
                    tokens = tokenize_java_code(code)
                    if not tokens:
                        self.skipped_files += 1
                        continue
                    self.data_list.append({'graph': graph, 'tokens': tokens, 'label': label})

    def len(self):
        """Number of valid (non-skipped) samples."""
        return len(self.data_list)

    def get(self, idx):
        """Return (edge_index, node_features, token_indices, label) tensors for sample *idx*."""
        data_item = self.data_list[idx]
        graph = data_item['graph']
        tokens = data_item['tokens']
        label = data_item['label']
        # edge_index must be (2, E) even when the graph has no edges; the
        # original torch.tensor([]).t() yields shape (0,), which breaks GCNConv.
        edges = list(graph.edges)
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Placeholder node features: just the node index as a float column.
        node_features = torch.arange(graph.number_of_nodes()).unsqueeze(1).float()
        # zlib.crc32 is a stable hash; builtin hash() is salted per process
        # (PYTHONHASHSEED), which made token ids irreproducible across runs.
        token_indices = torch.tensor(
            [zlib.crc32(t.encode('utf-8')) % 5000 for t in tokens], dtype=torch.long)
        return edge_index, node_features, token_indices, torch.tensor(label, dtype=torch.long)
| # ---- Models ---- | |
class GNNEncoder(nn.Module):
    """Two-layer GCN that pools node embeddings into a single graph vector."""

    def __init__(self, in_channels=1, hidden_dim=64):
        super().__init__()
        self.conv1 = torch_geometric.nn.GCNConv(in_channels, hidden_dim)
        self.conv2 = torch_geometric.nn.GCNConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index):
        """Return the mean of the final node embeddings (graph-level vector)."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        return hidden.mean(dim=0)
class RNNEncoder(nn.Module):
    """Embed a token-id sequence and encode it with a single-layer LSTM."""

    def __init__(self, vocab_size=5000, embedding_dim=64, hidden_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, tokens):
        """Return the LSTM's final hidden state, squeezed to (batch, hidden_dim)."""
        embedded = self.embedding(tokens)
        _, (final_hidden, _) = self.lstm(embedded)
        return final_hidden.squeeze(0)
class HybridClassifier(nn.Module):
    """Fuse GNN (structure) and RNN (token) embeddings into 2-way clone logits."""

    def __init__(self):
        super().__init__()
        self.gnn = GNNEncoder()
        self.rnn = RNNEncoder()
        # 64 (GNN) + 64 (RNN) = 128 fused features -> 2 logits.
        self.fc = nn.Linear(128, 2)

    def forward(self, edge_index, node_features, tokens):
        structural = self.gnn(node_features, edge_index)
        sequential = self.rnn(tokens)
        fused = torch.cat([structural, sequential], dim=-1)
        return self.fc(fused)
| # ---- Training and Evaluation ---- | |
def train(model, optimizer, loader, device):
    """Run one training epoch; return the mean cross-entropy loss.

    Expects *loader* to yield (edge_index, node_features, tokens, labels)
    batches of size 1, and *model* to return unbatched logits of shape (2,).
    """
    model.train()
    total_loss = 0.0
    for edge_index, node_features, tokens, labels in loader:
        edge_index = edge_index.to(device)
        node_features = node_features.to(device)
        tokens = tokens.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(edge_index, node_features, tokens)
        # Logits are unbatched (2,): add the batch dim on the input only.
        # The original also unsqueezed `labels`, producing a (1, 1) target
        # that F.cross_entropy rejects for a (1, 2) input; targets must be (1,).
        loss = F.cross_entropy(outputs.unsqueeze(0), labels.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    return total_loss / max(len(loader), 1)
def evaluate(model, loader, device):
    """Evaluate *model* on *loader*; return (precision, recall, f1).

    Metrics use 'binary' averaging, i.e. they score only the positive
    (clone = 1) class.
    """
    model.eval()
    preds, labels_all = [], []
    with torch.no_grad():
        for edge_index, node_features, tokens, labels in loader:
            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            tokens = tokens.to(device)
            labels = labels.to(device)
            outputs = model(edge_index, node_features, tokens)
            pred = outputs.argmax(dim=-1)
            # Collect plain Python ints. The original appended numpy arrays
            # and called np.concatenate, but `np` was only imported inside
            # __main__ — a NameError whenever this module is imported.
            preds.extend(pred.view(-1).cpu().tolist())
            labels_all.extend(labels.view(-1).cpu().tolist())
    # zero_division=0 gives defined metrics (and no warning) when a class
    # is never predicted.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels_all, preds, average='binary', zero_division=0)
    return precision, recall, f1
| # ---- Main Execution ---- | |
if __name__ == "__main__":
    import numpy as np  # noqa: F401 -- historically bound here as a module global

    dataset_root = 'archive (1)'
    zip_name = 'archive (1).zip'
    # Only extract when the archive is present; re-running after a previous
    # extraction should not crash on a missing zip or re-unpack everything.
    if os.path.exists(zip_name) and not os.path.isdir(dataset_root):
        unzip_dataset(zip_name, dataset_root)

    dataset = CloneDataset(dataset_root)
    print(f"Total valid samples: {dataset.len()}")
    print(f"Total skipped files: {dataset.skipped_files}")

    # 80 / 10 / 10 train / val / test split, reproducible via fixed seed.
    indices = list(range(dataset.len()))
    train_idx, temp_idx = train_test_split(indices, test_size=0.2, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
    train_set = torch.utils.data.Subset(dataset, train_idx)
    val_set = torch.utils.data.Subset(dataset, val_idx)
    test_set = torch.utils.data.Subset(dataset, test_idx)

    batch_size = 1  # graphs vary in size and no padding/collation is implemented
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HybridClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    epochs = 5
    start_time = time.time()
    for epoch in range(epochs):
        train_loss = train(model, optimizer, train_loader, device)
        precision, recall, f1 = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")

    precision, recall, f1 = evaluate(model, test_loader, device)
    total_time = time.time() - start_time
    print(f"Test Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    print(f"Total execution time: {total_time:.2f} seconds")