Jishnuuuu commited on
Commit
9e2f571
·
verified ·
1 Parent(s): 1dde640

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -236
app.py CHANGED
@@ -1,243 +1,30 @@
1
- from fastapi import FastAPI, HTTPException, Depends, Header
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from pydantic import BaseModel
4
- from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
5
- import torch
6
- from typing import Dict, List, Optional
7
- import os
8
- from firebase_service import firebase_service
9
 
10
- app = FastAPI(title="CyberGuard AI API", version="1.0.0")
 
 
11
 
12
- # Enable CORS for React Native app
13
- app.add_middleware(
14
- CORSMiddleware,
15
- allow_origins=["*"], # In production, specify your app's origin
16
- allow_credentials=True,
17
- allow_methods=["*"],
18
- allow_headers=["*"],
19
- )
20
-
21
- # Global variables for model and tokenizer
22
- model = None
23
- tokenizer = None
24
- device = None
25
-
26
- # Labels for classification (inverted based on model training)
27
- LABELS = {
28
- 0: "safe",
29
- 1: "threat"
30
- }
31
-
32
- SEVERITY_MAPPING = {
33
- "safe": "none",
34
- "threat": "high"
35
- }
36
-
37
- class MessageRequest(BaseModel):
38
- text: str
39
- app_name: Optional[str] = None
40
-
41
- class MessageResponse(BaseModel):
42
- text: str
43
- is_harmful: bool
44
- label: str
45
- confidence: float
46
- severity: str
47
- details: Dict
48
-
49
- class AuthenticatedUser(BaseModel):
50
- uid: str
51
- email: str
52
- email_verified: bool
53
-
54
- class UserStatsResponse(BaseModel):
55
- total_alerts: int
56
- high_severity: int
57
- medium_severity: int
58
- low_severity: int
59
-
60
- # Authentication dependency
61
- async def get_current_user(authorization: Optional[str] = Header(None)) -> Optional[AuthenticatedUser]:
62
- """Get current authenticated user from Firebase ID token"""
63
- if not authorization or not authorization.startswith('Bearer '):
64
- return None
65
-
66
- token = authorization.split('Bearer ')[1]
67
- user_data = firebase_service.verify_user_token(token)
68
-
69
- if user_data:
70
- return AuthenticatedUser(
71
- uid=user_data['uid'],
72
- email=user_data['email'],
73
- email_verified=user_data['email_verified']
74
- )
75
- return None
76
-
77
- @app.on_event("startup")
78
- async def load_model():
79
- """Load the model and tokenizer on startup"""
80
- global model, tokenizer, device
81
 
82
  try:
83
- # Get the model path
84
- model_path = os.path.join(
85
- os.path.dirname(os.path.dirname(__file__)),
86
- "CyberGuard Model V1",
87
- "final_cyberbullying_model"
88
- )
89
-
90
- print(f"Loading model from: {model_path}")
91
-
92
- # Check if model exists
93
- if not os.path.exists(model_path):
94
- print(f"⚠️ Model path not found: {model_path}")
95
- print("⚠️ Backend will run in demo mode without AI model")
96
- return
97
-
98
- # Set device
99
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100
- print(f"Using device: {device}")
101
-
102
- # Load tokenizer and model
103
- tokenizer = DistilBertTokenizer.from_pretrained(model_path)
104
- model = DistilBertForSequenceClassification.from_pretrained(model_path)
105
- model.to(device)
106
- model.eval()
107
-
108
- print("✅ Model loaded successfully!")
109
-
110
- except Exception as e:
111
- print(f"⚠️ Model loading failed: {str(e)}")
112
- print("⚠️ Backend will run in demo mode without AI model")
113
- # Don't raise - allow backend to start without model
114
- model = None
115
- tokenizer = None
116
-
117
- @app.get("/")
118
- async def root():
119
- """Health check endpoint"""
120
- return {
121
- "status": "online",
122
- "message": "CyberGuard AI API is running",
123
- "model_loaded": model is not None
124
- }
125
-
126
- @app.post("/analyze", response_model=MessageResponse)
127
- async def analyze_message(request: MessageRequest, current_user: Optional[AuthenticatedUser] = Depends(get_current_user)):
128
- """Analyze a message for cyberbullying/threats"""
129
-
130
- if model is None or tokenizer is None:
131
- raise HTTPException(status_code=503, detail="Model not loaded")
132
-
133
- if not request.text or len(request.text.strip()) == 0:
134
- raise HTTPException(status_code=400, detail="Text cannot be empty")
135
-
136
- try:
137
- # Tokenize input
138
- inputs = tokenizer(
139
- request.text,
140
- return_tensors="pt",
141
- truncation=True,
142
- max_length=512,
143
- padding=True
144
- )
145
-
146
- # Move to device
147
- inputs = {k: v.to(device) for k, v in inputs.items()}
148
-
149
- # Get prediction
150
- with torch.no_grad():
151
- outputs = model(**inputs)
152
- logits = outputs.logits
153
- probabilities = torch.nn.functional.softmax(logits, dim=-1)
154
- predicted_class = torch.argmax(probabilities, dim=-1).item()
155
- confidence = probabilities[0][predicted_class].item()
156
-
157
- # Get label
158
- label = LABELS.get(predicted_class, "unknown")
159
- is_harmful = label == "threat"
160
- severity = SEVERITY_MAPPING.get(label, "none")
161
-
162
- # Additional details
163
- details = {
164
- "safe_probability": round(probabilities[0][0].item(), 4),
165
- "threat_probability": round(probabilities[0][1].item(), 4),
166
- "model": "DistilBERT",
167
- "version": "1.0"
168
  }
169
-
170
- # Store alert if user is authenticated and threat is detected
171
- if current_user and is_harmful:
172
- alert_data = {
173
- 'timestamp': None, # Will be set by Firebase
174
- 'severity': severity,
175
- 'threat_type': 'cyberbullying',
176
- 'app_name': request.app_name,
177
- 'confidence': round(confidence, 4),
178
- 'model_version': '1.0'
179
- }
180
- firebase_service.store_threat_alert(current_user.uid, alert_data)
181
-
182
- return MessageResponse(
183
- text=request.text,
184
- is_harmful=is_harmful,
185
- label=label,
186
- confidence=round(confidence, 4),
187
- severity=severity,
188
- details=details
189
- )
190
-
191
  except Exception as e:
192
- raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
193
-
194
- @app.post("/analyze-batch")
195
- async def analyze_batch(messages: List[str]):
196
- """Analyze multiple messages at once"""
197
-
198
- if model is None or tokenizer is None:
199
- raise HTTPException(status_code=503, detail="Model not loaded")
200
-
201
- results = []
202
- for text in messages:
203
- if text and len(text.strip()) > 0:
204
- response = await analyze_message(MessageRequest(text=text))
205
- results.append(response.dict())
206
-
207
- return {"results": results, "count": len(results)}
208
-
209
- @app.get("/user/stats", response_model=UserStatsResponse)
210
- async def get_user_stats(current_user: AuthenticatedUser = Depends(get_current_user)):
211
- """Get user's threat alert statistics"""
212
- if not current_user:
213
- raise HTTPException(status_code=401, detail="Authentication required")
214
-
215
- stats = firebase_service.get_user_alert_stats(current_user.uid)
216
- return UserStatsResponse(**stats)
217
-
218
- @app.delete("/user/data")
219
- async def delete_user_data(current_user: AuthenticatedUser = Depends(get_current_user)):
220
- """Delete all user data (GDPR compliance)"""
221
- if not current_user:
222
- raise HTTPException(status_code=401, detail="Authentication required")
223
-
224
- success = firebase_service.delete_user_data(current_user.uid)
225
- if success:
226
- return {"message": "User data deleted successfully"}
227
- else:
228
- raise HTTPException(status_code=500, detail="Failed to delete user data")
229
-
230
- @app.get("/health")
231
- async def health_check():
232
- """Detailed health check"""
233
- return {
234
- "status": "healthy",
235
- "model_loaded": model is not None,
236
- "tokenizer_loaded": tokenizer is not None,
237
- "device": str(device) if device else "unknown",
238
- "firebase_initialized": firebase_service.is_initialized()
239
- }
240
 
241
- if __name__ == "__main__":
242
- import uvicorn
243
- uvicorn.run(app, host="0.0.0.0", port=8001)
 
1
+ import gradio as gr
2
+ from transformers import pipeline
 
 
 
 
 
 
3
 
4
+ print("🔄 Loading CyberGuard Model...")
5
+ classifier = pipeline("text-classification", model="Jishnuuuu/cyberguard-v1")
6
+ print("✅ Model loaded successfully!")
7
 
8
+ def analyze_message(text):
9
+ """Analyze message for cyberbullying threats"""
10
+ if not text or len(text.strip()) == 0:
11
+ return {"Label": "ERROR", "Confidence": 0.0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  try:
14
+ result = classifier(text)[0]
15
+ return {
16
+ "Label": result['label'],
17
+ "Confidence": round(result['score'], 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  except Exception as e:
20
+ return {"Label": "ERROR", "Confidence": 0.0, "error": str(e)}
21
+
22
+ demo = gr.Interface(
23
+ fn=analyze_message,
24
+ inputs=gr.Textbox(label="Message", placeholder="Type a message..."),
25
+ outputs=gr.JSON(label="Result"),
26
+ title="🛡️ CyberGuard AI",
27
+ description="Cyberbullying Detection - Model: Jishnuuuu/cyberguard-v1"
28
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ demo.launch()