ghosthets committed
Commit 48cf71a · verified · 1 Parent(s): 620d411

Update app.py

Files changed (1)
app.py +33 -14
app.py CHANGED
@@ -1,38 +1,47 @@
 import flask
 from flask import request, jsonify
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer  # Added AutoTokenizer
 import torch
-import warnings  # for suppressing warnings
+import warnings
 
-# Suppress warnings, otherwise warnings can appear on CPU
+# Suppress minor warnings that occur on CPU runs
 warnings.filterwarnings("ignore")
 
 app = flask.Flask(__name__)
 
 # ===========================
-# LOAD MODEL (StableLM-3B-Chat)
+# LOAD MODEL (SmolLM-1.7B-Chat)
+# This model is small (1.7B) and fully open-access.
 # ===========================
 model_id = "HuggingFaceTB/SmolLM-1.7B"
 print("🔄 Loading model...")
 
 # CPU/GPU device set
-# When loading on CPU we will try to reduce memory by using 'torch.bfloat16'.
 device = 0 if torch.cuda.is_available() else -1
-dtype = torch.float32 if device == -1 else torch.bfloat16  # float32 for CPU
+# Use float32 for CPU (or bfloat16 for GPU)
+dtype = torch.float32 if device == -1 else torch.bfloat16
 
 try:
+    # 1. Load Tokenizer and set pad_token for stability
+    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        # Set pad_token to eos_token to fix generation warning/error
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # 2. Load Pipeline with the fixed tokenizer
     ai = pipeline(
         "text-generation",
         model=model_id,
+        tokenizer=tokenizer,  # Passing the configured tokenizer here
         max_new_tokens=200,
         device=device,
-        torch_dtype=dtype,  # CPU/Memory optimization
-        trust_remote_code=True  # required for StableLM
+        torch_dtype=dtype,
+        trust_remote_code=True
     )
     print("✅ Model loaded!")
 except Exception as e:
     print(f"❌ Error loading model: {e}")
-    ai = None  # If load fails, prevent later API errors
+    ai = None
 
 # ===========================
 # CHAT API
@@ -48,14 +57,24 @@ def chat():
         if not msg:
             return jsonify({"error": "No message sent"}), 400
 
-        # StableLM Instruction Format:
-        prompt = f"<|user|>\n{msg}<|end|>\n<|assistant|>"
+        # Instruction Format: Using a simple template for this model
+        prompt = f"User: {msg}\nAssistant:"
 
         output = ai(prompt)[0]["generated_text"]
 
-        # Clean the output so that only the assistant's reply remains
-        reply = output.split("<|assistant|>")[-1].strip()
-
+        # Clean the output to extract only the model's reply
+        # We split based on the 'Assistant:' tag in the prompt template
+        if "Assistant:" in output:
+            reply = output.split("Assistant:")[-1].strip()
+        elif "User:" in output:  # Sometimes the model repeats the prompt
+            reply = output.split("User:")[0].strip()
+        else:
+            reply = output.strip()
+
+        # Remove any remaining instruction markers from the start
+        if reply.startswith(msg):
+            reply = reply[len(msg):].strip()
+
         return jsonify({"reply": reply})
     except Exception as e:
         return jsonify({"error": str(e)}), 500
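
For quick sanity-checking, a minimal standalone sketch (not part of the commit) that mirrors the new prompt template and reply-cleaning logic with the text-generation pipeline stubbed out. The helper names build_prompt and clean_reply are introduced here for illustration only.

# Minimal sketch: exercises the prompt template and cleanup added in this commit.
# The pipeline call is stubbed out; helper names are hypothetical.

def build_prompt(msg: str) -> str:
    # Same simple template the updated app.py builds before calling ai()
    return f"User: {msg}\nAssistant:"

def clean_reply(output: str, msg: str) -> str:
    # Mirrors the cleanup logic in the new chat() handler
    if "Assistant:" in output:
        reply = output.split("Assistant:")[-1].strip()
    elif "User:" in output:  # model sometimes repeats the prompt
        reply = output.split("User:")[0].strip()
    else:
        reply = output.strip()
    # Drop an echoed user message, if any
    if reply.startswith(msg):
        reply = reply[len(msg):].strip()
    return reply

if __name__ == "__main__":
    msg = "What is Flask?"
    # Stand-in for pipeline(prompt)[0]["generated_text"], which echoes the prompt
    output = build_prompt(msg) + " Flask is a lightweight Python web framework."
    print(clean_reply(output, msg))  # -> Flask is a lightweight Python web framework.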