Juno360219 commited on
Commit
0bcd1db
·
verified ·
1 Parent(s): 6021a8e

Delete nfsw_filter.py

Browse files
Files changed (1) hide show
  1. nfsw_filter.py +0 -272
nfsw_filter.py DELETED
@@ -1,272 +0,0 @@
1
- # Strict mode: When predicted as NSFW, apply image blurring.
2
-
3
- import os
4
- from PIL import Image
5
- import numpy as np
6
- import onnxruntime as ort
7
- import json
8
- from huggingface_hub import hf_hub_download
9
-
10
-
11
- class NSFWDetector:
12
- """
13
- NSFW detector using YOLOv9 for image classification.
14
- """
15
-
16
- def __init__(self, repo_id="Falconsai/nsfw_image_detection",
17
- model_filename="falconsai_yolov9_nsfw_model_quantized.pt",
18
- labels_filename="labels.json",
19
- input_size=(224, 224)):
20
- """
21
- Initialize the NSFW detector.
22
-
23
- Args:
24
- repo_id (str): Hugging Face repository ID.
25
- model_filename (str): Model filename.
26
- labels_filename (str): Labels filename.
27
- input_size (tuple): Model input size (height, width).
28
- """
29
- self.repo_id = repo_id
30
- self.model_filename = model_filename
31
- self.labels_filename = labels_filename
32
- self.input_size = input_size
33
-
34
- # Download files from Hugging Face
35
- self.model_path = self._download_model()
36
- self.labels_path = self._download_labels()
37
-
38
- # Load labels
39
- self.labels = self._load_labels()
40
-
41
- # Load model
42
- self.session = self._load_model()
43
- self.input_name = self.session.get_inputs()[0].name
44
- self.output_name = self.session.get_outputs()[0].name
45
-
46
- def _download_model(self):
47
- """
48
- Download the model file from Hugging Face.
49
-
50
- Returns:
51
- str: Path to the downloaded model file.
52
- """
53
- try:
54
- print(f"Downloading model from {self.repo_id}: {self.model_filename}")
55
- model_path = hf_hub_download(
56
- repo_id=self.repo_id,
57
- filename=self.model_filename,
58
- cache_dir="./hf_cache"
59
- )
60
- print(f"✅ Model downloaded: {model_path}")
61
- return model_path
62
- except Exception as e:
63
- raise RuntimeError(f"Model download failed: {e}")
64
-
65
- def _download_labels(self):
66
- """
67
- Download the labels file from Hugging Face.
68
-
69
- Returns:
70
- str: Path to the downloaded labels file.
71
- """
72
- try:
73
- print(f"Downloading labels from {self.repo_id}: {self.labels_filename}")
74
- labels_path = hf_hub_download(
75
- repo_id=self.repo_id,
76
- filename=self.labels_filename,
77
- cache_dir="./hf_cache"
78
- )
79
- print(f"✅ Labels downloaded: {labels_path}")
80
- return labels_path
81
- except Exception as e:
82
- raise RuntimeError(f"Labels download failed: {e}")
83
-
84
- def _load_labels(self):
85
- """
86
- Load class labels.
87
-
88
- Returns:
89
- dict: Labels dictionary.
90
- """
91
- try:
92
- with open(self.labels_path, "r") as f:
93
- return json.load(f)
94
- except FileNotFoundError:
95
- raise FileNotFoundError(f"Labels file not found: {self.labels_path}")
96
- except json.JSONDecodeError:
97
- raise ValueError(f"Labels file is malformed: {self.labels_path}")
98
-
99
- def _load_model(self):
100
- """
101
- Load ONNX model.
102
-
103
- Returns:
104
- onnxruntime.InferenceSession: Model session.
105
- """
106
- try:
107
- return ort.InferenceSession(self.model_path)
108
- except Exception as e:
109
- raise RuntimeError(f"Model load failed: {self.model_path}, error: {e}")
110
-
111
- def _preprocess_image(self, image_path):
112
- """
113
- Preprocess image.
114
-
115
- Args:
116
- image_path (str): Image file path.
117
-
118
- Returns:
119
- tuple: (preprocessed tensor, original image)
120
- """
121
- try:
122
- # Load and convert image
123
- original_image = Image.open(image_path).convert("RGB")
124
-
125
- # Resize
126
- image_resized = original_image.resize(self.input_size, Image.Resampling.BILINEAR)
127
-
128
- # To numpy and normalize
129
- image_np = np.array(image_resized, dtype=np.float32) / 255.0
130
-
131
- # Reorder dims [H, W, C] -> [C, H, W]
132
- image_np = np.transpose(image_np, (2, 0, 1))
133
-
134
- # Add batch dim [C, H, W] -> [1, C, H, W]
135
- input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
136
-
137
- return input_tensor, original_image
138
-
139
- except FileNotFoundError:
140
- raise FileNotFoundError(f"Image file not found: {image_path}")
141
- except Exception as e:
142
- raise RuntimeError(f"Image preprocessing failed: {e}")
143
-
144
- def _postprocess_predictions(self, predictions):
145
- """
146
- Postprocess model predictions.
147
-
148
- Args:
149
- predictions: Model output.
150
-
151
- Returns:
152
- str: Predicted class label.
153
- """
154
- predicted_index = np.argmax(predictions)
155
- predicted_label = self.labels[str(predicted_index)]
156
- return predicted_label
157
-
158
- def predict(self, image_path):
159
- """
160
- Run NSFW detection on a single image.
161
-
162
- Args:
163
- image_path (str): Image file path.
164
-
165
- Returns:
166
- tuple: (predicted label, original image)
167
- """
168
- # Preprocess image
169
- input_tensor, original_image = self._preprocess_image(image_path)
170
-
171
- # Run inference
172
- outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
173
- predictions = outputs[0]
174
-
175
- # Postprocess
176
- predicted_label = self._postprocess_predictions(predictions)
177
-
178
- return predicted_label, original_image
179
-
180
- def predict_label_only(self, image_path):
181
- """
182
- Return only the predicted label (no image).
183
-
184
- Args:
185
- image_path (str): Image file path.
186
-
187
- Returns:
188
- str: Predicted class label.
189
- """
190
- predicted_label, _ = self.predict(image_path)
191
- return predicted_label
192
-
193
- def predict_from_pil(self, pil_image):
194
- """
195
- Run NSFW detection from a PIL Image object.
196
-
197
- Args:
198
- pil_image (PIL.Image): PIL image object.
199
-
200
- Returns:
201
- tuple: (predicted label, original image)
202
- """
203
- try:
204
- # Ensure RGB
205
- if pil_image.mode != "RGB":
206
- pil_image = pil_image.convert("RGB")
207
-
208
- # Resize
209
- image_resized = pil_image.resize(self.input_size, Image.Resampling.BILINEAR)
210
-
211
- # To numpy and normalize
212
- image_np = np.array(image_resized, dtype=np.float32) / 255.0
213
-
214
- # Reorder dims [H, W, C] -> [C, H, W]
215
- image_np = np.transpose(image_np, (2, 0, 1))
216
-
217
- # Add batch dim [C, H, W] -> [1, C, H, W]
218
- input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
219
-
220
- # Run inference
221
- outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
222
- predictions = outputs[0]
223
-
224
- # Postprocess
225
- predicted_label = self._postprocess_predictions(predictions)
226
-
227
- return predicted_label, pil_image
228
-
229
- except Exception as e:
230
- raise RuntimeError(f"PIL image prediction failed: {e}")
231
-
232
- def predict_pil_label_only(self, pil_image):
233
- """
234
- Return only the predicted label from a PIL Image.
235
-
236
- Args:
237
- pil_image (PIL.Image): PIL image object.
238
-
239
- Returns:
240
- str: Predicted class label.
241
- """
242
- predicted_label, _ = self.predict_from_pil(pil_image)
243
- return predicted_label
244
-
245
- # --- Usage example ---
246
- if __name__ == "__main__":
247
- # Config
248
- single_image_path = "datas/bad01.jpg"
249
-
250
- try:
251
- # Create detector (auto-download from Hugging Face)
252
- detector = NSFWDetector()
253
-
254
- # Check image file exists
255
- if os.path.exists(single_image_path):
256
- # Run prediction
257
- predicted_label = detector.predict_label_only(single_image_path)
258
- print(f"Image file: {single_image_path}")
259
- print(f"Prediction: {predicted_label}")
260
-
261
- # Strict NSFW check
262
- if predicted_label.lower() == "nsfw":
263
- print("❌ NSFW content detected, image will be blurred!")
264
- # TODO: Add image blur logic here
265
- else:
266
- print("✅ Image content is normal, allowed to use")
267
-
268
- else:
269
- print(f"Error: Image file does not exist: {single_image_path}")
270
-
271
- except Exception as e:
272
- print(f"Error initializing detector: {e}")