ahuggingface01 commited on
Commit
624fd4e
·
verified ·
1 Parent(s): e13728c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -10,21 +10,20 @@ from diffusers import FluxPipeline
10
  from insightface.app import FaceAnalysis
11
  from insightface.model_zoo import get_model
12
 
13
- # --- GLOBAL CONFIG (CPU ONLY) ---
14
  MODEL_ID = "black-forest-labs/FLUX.1-dev"
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
 
17
- # We define the models as None globally and load them inside the GPU function
18
  face_app = None
19
  swapper = None
20
  pipe = None
21
 
22
- def load_models():
23
- """Initializes models inside the GPU-allocated context."""
24
  global face_app, swapper, pipe
25
 
26
  if face_app is None:
27
- # Use CPU provider initially to avoid startup crashes
28
  face_app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
29
  face_app.prepare(ctx_id=0, det_size=(640, 640))
30
 
@@ -39,7 +38,6 @@ def load_models():
39
  torch_dtype=torch.bfloat16,
40
  token=HF_TOKEN
41
  )
42
- # Offloading helps manage ZeroGPU's 70GB VRAM efficiently
43
  pipe.enable_model_cpu_offload()
44
 
45
  def upscale_image(image):
@@ -52,28 +50,30 @@ def upscale_image(image):
52
 
53
  @spaces.GPU(duration=150)
54
  def generate_vton_final(face_image, body_type, height_ft):
55
- if face_image is None: return None, "Upload face image."
 
56
 
57
- # Trigger model loading within the GPU context
58
- load_models()
59
 
60
  # 1. Face Analysis
61
  img_np = np.array(face_image)
62
  cv_img = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
63
  faces = face_app.get(cv_img)
64
- if not faces: return None, "No face detected."
 
65
 
66
  source_face = faces[0]
67
  gender = "man" if source_face.gender == 1 else "woman"
68
 
69
- # 2. Simplified Prompt (Normal Pose)
70
  profile_seed = int(hashlib.md5(f"{gender}-{body_type}".encode()).hexdigest(), 16) % (10**9)
71
  generator = torch.Generator("cuda").manual_seed(profile_seed)
72
 
73
  prompt = (
74
- f"Full body photo of a {gender}, {body_type} build, {height_ft}ft tall. "
75
- f"Standing in a relaxed normal pose, facing camera, neutral expression. "
76
- f"Casual high-quality clothing, simple studio background, sharp focus, 8k ultra-hd."
77
  )
78
 
79
  # 3. Generation
@@ -91,25 +91,26 @@ def generate_vton_final(face_image, body_type, height_ft):
91
  res_cv = cv2.cvtColor(res_np, cv2.COLOR_RGB2BGR)
92
  target_faces = face_app.get(res_cv)
93
  if target_faces:
94
- # Swap onto the largest detected face
95
  target_faces = sorted(target_faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)
96
  res_cv = swapper.get(res_cv, target_faces[0], source_face, paste_back=True)
97
  gen_img = Image.fromarray(cv2.cvtColor(res_cv, cv2.COLOR_BGR2RGB))
98
 
99
- return upscale_image(gen_img), f"Status: Complete | Seed: {profile_seed}"
 
100
 
101
- # --- GRADIO ---
102
- with gr.Blocks(css=".gradio-container {background-color: #f9f9f9}") as demo:
103
- gr.Markdown("## 💎 HD Virtual Model Generator")
104
  with gr.Row():
105
  with gr.Column():
106
- face_in = gr.Image(type="pil", label="Face Photo")
107
- body_in = gr.Radio(["slim", "muscular", "average"], value="average", label="Body Shape")
108
- h_in = gr.Slider(4.5, 6.5, value=5.7, step=0.1, label="Height (ft)")
109
- btn = gr.Button("Generate Model", variant="primary")
110
  with gr.Column():
111
- img_out = gr.Image(label="Result")
112
- status = gr.Textbox(label="System Logs")
113
 
114
  btn.click(generate_vton_final, [face_in, body_in, h_in], [img_out, status])
115
 
 
10
  from insightface.app import FaceAnalysis
11
  from insightface.model_zoo import get_model
12
 
13
+ # --- GLOBAL CONFIG ---
14
  MODEL_ID = "black-forest-labs/FLUX.1-dev"
15
  HF_TOKEN = os.getenv("HF_TOKEN")
16
 
17
+ # Initialize models as None for ZeroGPU lazy loading
18
  face_app = None
19
  swapper = None
20
  pipe = None
21
 
22
+ def load_models_on_gpu():
23
+ """Initializes models only when GPU is allocated."""
24
  global face_app, swapper, pipe
25
 
26
  if face_app is None:
 
27
  face_app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
28
  face_app.prepare(ctx_id=0, det_size=(640, 640))
29
 
 
38
  torch_dtype=torch.bfloat16,
39
  token=HF_TOKEN
40
  )
 
41
  pipe.enable_model_cpu_offload()
42
 
43
  def upscale_image(image):
 
50
 
51
  @spaces.GPU(duration=150)
52
  def generate_vton_final(face_image, body_type, height_ft):
53
+ if face_image is None:
54
+ return None, "Please upload a face image."
55
 
56
+ # Ensure models are loaded in the GPU context
57
+ load_models_on_gpu()
58
 
59
  # 1. Face Analysis
60
  img_np = np.array(face_image)
61
  cv_img = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
62
  faces = face_app.get(cv_img)
63
+ if not faces:
64
+ return None, "No face detected in the upload."
65
 
66
  source_face = faces[0]
67
  gender = "man" if source_face.gender == 1 else "woman"
68
 
69
+ # 2. Simplified Prompt (Normal Pose & Casual Clothes)
70
  profile_seed = int(hashlib.md5(f"{gender}-{body_type}".encode()).hexdigest(), 16) % (10**9)
71
  generator = torch.Generator("cuda").manual_seed(profile_seed)
72
 
73
  prompt = (
74
+ f"A full body 8k professional photo of a {gender}, {body_type} build, {height_ft}ft tall. "
75
+ f"Standing in a relaxed, natural pose, facing the camera. "
76
+ f"Wearing stylish casual clothing, clean studio background, sharp focus, cinematic lighting."
77
  )
78
 
79
  # 3. Generation
 
91
  res_cv = cv2.cvtColor(res_np, cv2.COLOR_RGB2BGR)
92
  target_faces = face_app.get(res_cv)
93
  if target_faces:
94
+ # Sort to find the main person in the photo
95
  target_faces = sorted(target_faces, key=lambda x: (x.bbox[2]-x.bbox[0])*(x.bbox[3]-x.bbox[1]), reverse=True)
96
  res_cv = swapper.get(res_cv, target_faces[0], source_face, paste_back=True)
97
  gen_img = Image.fromarray(cv2.cvtColor(res_cv, cv2.COLOR_BGR2RGB))
98
 
99
+ # 5. HD Upscale
100
+ return upscale_image(gen_img), f"Success | Seed: {profile_seed}"
101
 
102
# --- GRADIO INTERFACE ---
# Layout: two columns — inputs (face photo, body build, height) on the left,
# generated image and status log on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 💎 AI Virtual Model Engine")
    with gr.Row():
        with gr.Column():
            face_in = gr.Image(type="pil", label="Step 1: Upload Face")
            body_in = gr.Radio(
                ["slim", "muscular", "average"],
                value="average",
                label="Step 2: Body Build",
            )
            h_in = gr.Slider(4.5, 7.0, value=5.8, step=0.1, label="Step 3: Height (ft)")
            btn = gr.Button("Generate High-Res Model", variant="primary")
        with gr.Column():
            img_out = gr.Image(label="Final Result")
            status = gr.Textbox(label="Logs")

    # Wire the button to the GPU-decorated generator defined above.
    btn.click(generate_vton_final, [face_in, body_in, h_in], [img_out, status])