import os
import re
import time
import zipfile

import javalang
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torch_geometric.data import Dataset
from torch_geometric.nn import GCNConv
# ---- Utility functions ----
def unzip_dataset(zip_path, extract_to):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
def normalize_java_code(code):
    # Remove single-line comments
    code = re.sub(r'//.*?\n', '', code)
    # Remove multi-line comments
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
    # Collapse all whitespace runs (spaces, tabs, newlines) to a single space
    code = re.sub(r'\s+', ' ', code)
    return code.strip()
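# Illustration (not executed): comments vanish and whitespace collapses, e.g.
#   normalize_java_code("int a = 1; // counter\nint b = 2;")
#   -> "int a = 1; int b = 2;"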
def safe_parse_java(code):
    try:
        tokens = list(javalang.tokenizer.tokenize(code))
        parser = javalang.parser.Parser(tokens)
        return parser.parse()
    except Exception:
        return None
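# Usage sketch: returns a javalang CompilationUnit on success, None on any
# tokenizer/parser error, e.g.
#   safe_parse_java("class A { void f() {} }")  # -> CompilationUnit
#   safe_parse_java("not java at all")          # -> None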
def ast_to_graph(ast):
    graph = nx.DiGraph()
    def dfs(node, parent_id=None):
        node_id = len(graph)
        graph.add_node(node_id, label=type(node).__name__)
        if parent_id is not None:
            graph.add_edge(parent_id, node_id)
        for child in getattr(node, 'children', []):
            if isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        dfs(item, node_id)
            elif isinstance(child, javalang.ast.Node):
                dfs(child, node_id)
    dfs(ast)
    return graph
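# For "class A { void f() {} }" the result is a small tree rooted at node 0
# (labelled "CompilationUnit"), with edges pointing from each parent AST node
# to its children; node labels are the javalang node class names.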
def tokenize_java_code(code):
    try:
        tokens = list(javalang.tokenizer.tokenize(code))
        return [token.value for token in tokens]
    except Exception:
        return []
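# Example: tokenize_java_code("int a = 1;") -> ['int', 'a', '=', '1', ';']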
# ---- Data Preprocessing ----
class CloneDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        super().__init__(transform=transform)
        self.data_list = []
        self.labels = []
        self.skipped_files = 0
        self.max_tokens = 5000
        clone_dirs = {
            "Clone_Type1": 1,
            "Clone_Type2": 1,
            "Clone_Type3 - ST": 1,
            "Clone_Type3 - VST": 1,
            "Clone_Type3 - MT": 0  # Assumption: treat moderately type-3 (MT) pairs as non-clones
        }
        for clone_type, label in clone_dirs.items():
            clone_path = os.path.join(root_dir, 'Subject_CloneTypes_Directories', clone_type)
            for root, _, files in os.walk(clone_path):
                for file in files:
                    if not file.endswith(".java"):
                        continue
                    file_path = os.path.join(root, file)
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        code = f.read()
                    code = normalize_java_code(code)
                    if len(code.split()) > self.max_tokens:
                        self.skipped_files += 1
                        continue
                    ast = safe_parse_java(code)
                    if ast is None:
                        self.skipped_files += 1
                        continue
                    graph = ast_to_graph(ast)
                    tokens = tokenize_java_code(code)
                    if not tokens:
                        self.skipped_files += 1
                        continue
                    self.data_list.append({
                        'graph': graph,
                        'tokens': tokens,
                        'label': label
                    })

    def len(self):
        return len(self.data_list)

    def get(self, idx):
        data_item = self.data_list[idx]
        graph = data_item['graph']
        tokens = data_item['tokens']
        label = data_item['label']
        # Graph processing (guard against single-node ASTs, which have no edges)
        if graph.number_of_edges() > 0:
            edge_index = torch.tensor(list(graph.edges), dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        node_features = torch.arange(graph.number_of_nodes()).unsqueeze(1).float()
        # Token processing: hashing trick into 5000 buckets (note: hash() is
        # salted per process unless PYTHONHASHSEED is fixed, so indices vary
        # across runs)
        token_indices = torch.tensor([hash(t) % 5000 for t in tokens], dtype=torch.long)
        return edge_index, node_features, token_indices, torch.tensor(label, dtype=torch.long)
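# Each sample from get() is a 4-tuple whose sizes depend on the source file:
#   edge_index    : LongTensor (2, E)  parent->child AST edges
#   node_features : FloatTensor (N, 1) node index used as a stand-in feature
#   token_indices : LongTensor (T,)    hashed token ids in [0, 5000)
#   label         : LongTensor scalar  1 = clone, 0 = non-clone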
# ---- Models ----
class GNNEncoder(nn.Module):
    def __init__(self, in_channels=1, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return torch.mean(x, dim=0)  # Mean-pool node embeddings into a graph-level embedding
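# Shape sketch: x (N, 1) and edge_index (2, E) go in; two GCN layers yield
# (N, 64) node embeddings, and the mean over nodes gives a (64,) graph vector.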
class RNNEncoder(nn.Module):
    def __init__(self, vocab_size=5000, embedding_dim=64, hidden_dim=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, tokens):
        embeds = self.embedding(tokens)
        _, (hidden, _) = self.lstm(embeds)
        return hidden.squeeze(0)
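# Shape sketch: an unbatched (T,) index tensor embeds to (T, 64); the final
# hidden state (1, 64) is squeezed to a (64,) sequence embedding. Unbatched
# LSTM inputs require PyTorch >= 1.11.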
class HybridClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.gnn = GNNEncoder()
        self.rnn = RNNEncoder()
        self.fc = nn.Linear(128, 2)  # 64 (GNN) + 64 (RNN) -> 2 classes

    def forward(self, edge_index, node_features, tokens):
        gnn_out = self.gnn(node_features, edge_index)
        rnn_out = self.rnn(tokens)
        combined = torch.cat([gnn_out, rnn_out], dim=-1)
        return self.fc(combined)
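# Forward-pass sketch on a hypothetical 3-node toy graph (illustrative only):
#   model = HybridClassifier()
#   ei = torch.tensor([[0, 0], [1, 2]])        # edges 0->1, 0->2
#   nf = torch.arange(3).unsqueeze(1).float()  # (3, 1) node features
#   tk = torch.tensor([7, 42, 99])             # (3,) token ids
#   model(ei, nf, tk)                          # -> (2,) logits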
# ---- Training and Evaluation ----
def train(model, optimizer, loader, device):
    model.train()
    total_loss = 0
    for edge_index, node_features, tokens, labels in loader:
        edge_index = edge_index.to(device)
        node_features = node_features.to(device)
        tokens = tokens.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(edge_index, node_features, tokens)
        loss = F.cross_entropy(outputs.unsqueeze(0), labels.unsqueeze(0))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
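# Note: with the single-sample loaders built in __main__, `outputs` is a (2,)
# logit vector and `labels` a 0-d scalar; both are unsqueezed so cross_entropy
# sees the (batch, classes) / (batch,) shapes it expects.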
def evaluate(model, loader, device):
    model.eval()
    preds, labels_all = [], []
    with torch.no_grad():
        for edge_index, node_features, tokens, labels in loader:
            edge_index = edge_index.to(device)
            node_features = node_features.to(device)
            tokens = tokens.to(device)
            labels = labels.to(device)
            outputs = model(edge_index, node_features, tokens)
            pred = outputs.argmax(dim=-1)
            # view(1) lifts the 0-d scalar tensors to 1-d so np.concatenate works
            preds.append(pred.view(1).cpu().numpy())
            labels_all.append(labels.view(1).cpu().numpy())
    preds = np.concatenate(preds)
    labels_all = np.concatenate(labels_all)
    precision, recall, f1, _ = precision_recall_fscore_support(labels_all, preds, average='binary')
    return precision, recall, f1
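# average='binary' scores the positive (clone) class only; the concatenated
# arrays hold one prediction and one label per file in the split.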
# ---- Main Execution ----
if __name__ == "__main__":
    dataset_root = 'archive (1)'
    unzip_dataset('archive (1).zip', dataset_root)
    dataset = CloneDataset(dataset_root)
    print(f"Total valid samples: {dataset.len()}")
    print(f"Total skipped files: {dataset.skipped_files}")
    indices = list(range(dataset.len()))
    train_idx, temp_idx = train_test_split(indices, test_size=0.2, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
    train_set = Subset(dataset, train_idx)
    val_set = Subset(dataset, val_idx)
    test_set = Subset(dataset, test_idx)

    def collate_one(batch):
        # Graphs and token sequences vary in size, so samples are processed one
        # at a time; unwrap the singleton batch back into a plain 4-tuple (the
        # default collate would stack tensors and add a spurious batch dim).
        return batch[0]

    batch_size = 1  # small because of variable graph sizes
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_one)
    val_loader = DataLoader(val_set, batch_size=batch_size, collate_fn=collate_one)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=collate_one)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = HybridClassifier().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    epochs = 5
    start_time = time.time()
    for epoch in range(epochs):
        train_loss = train(model, optimizer, train_loader, device)
        precision, recall, f1 = evaluate(model, val_loader, device)
        print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
    precision, recall, f1 = evaluate(model, test_loader, device)
    total_time = time.time() - start_time
    print(f"Test Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    print(f"Total execution time: {total_time:.2f} seconds")