Patryk Studzinski committed on
Commit e0c72ee · 1 Parent(s): 04a726c

fix: Add bitsandbytes to requirements and graceful fallback for 8-bit quantization

Files changed (2):
  1. app/models/huggingface_local.py +32 -11
  2. requirements.txt +8 -5
app/models/huggingface_local.py CHANGED
@@ -8,13 +8,21 @@ Optimizations:
 """
 
 from typing import List, Dict, Any, Optional
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 import torch
 import asyncio
 import os
 
 from app.models.base_llm import BaseLLM
 
+# Try to import bitsandbytes, but don't fail if not available
+try:
+    from transformers import BitsAndBytesConfig
+    HAS_BITSANDBYTES = True
+except ImportError:
+    HAS_BITSANDBYTES = False
+    print("[WARNING] bitsandbytes not available - 8-bit quantization disabled")
+
 
 class HuggingFaceLocal(BaseLLM):
     """
@@ -35,7 +43,14 @@ class HuggingFaceLocal(BaseLLM):
         self.tokenizer = None
         self.model = None
         self.use_cache = use_cache
-        self.use_8bit = use_8bit or (device == "cpu" and os.getenv("USE_8BIT_QUANTIZATION", "true").lower() == "true")
+
+        # Only enable 8-bit if bitsandbytes is available
+        requested_8bit = use_8bit or (device == "cpu" and os.getenv("USE_8BIT_QUANTIZATION", "true").lower() == "true")
+        self.use_8bit = requested_8bit and HAS_BITSANDBYTES
+
+        if requested_8bit and not HAS_BITSANDBYTES:
+            print(f"[{name}] 8-bit quantization requested but bitsandbytes not installed - falling back to full precision")
+
         self.use_flash_attention = os.getenv("USE_FLASH_ATTENTION", "true").lower() == "true"
 
         # Determine device index and dtype
@@ -68,15 +83,21 @@ class HuggingFaceLocal(BaseLLM):
         }
 
         # Add 8-bit quantization for CPU (4-6x faster, 50% less memory)
-        if self.use_8bit:
-            print(f"[{self.name}] Using 8-bit quantization for CPU optimization")
-            bnb_config = BitsAndBytesConfig(
-                load_in_8bit=True,
-                bnb_8bit_compute_dtype=torch.float16,
-                bnb_8bit_use_double_quant=True,
-            )
-            model_kwargs["quantization_config"] = bnb_config
-            model_kwargs["device_map"] = "cpu"
+        if self.use_8bit and HAS_BITSANDBYTES:
+            try:
+                print(f"[{self.name}] Using 8-bit quantization for CPU optimization")
+                bnb_config = BitsAndBytesConfig(
+                    load_in_8bit=True,
+                    bnb_8bit_compute_dtype=torch.float16,
+                    bnb_8bit_use_double_quant=True,
+                )
+                model_kwargs["quantization_config"] = bnb_config
+                model_kwargs["device_map"] = "cpu"
+            except Exception as e:
+                print(f"[{self.name}] Failed to setup 8-bit quantization: {e}")
+                print(f"[{self.name}] Falling back to full precision")
+                model_kwargs["torch_dtype"] = self.torch_dtype
+                model_kwargs["device_map"] = self.device if self.device == "cuda" else "cpu"
         else:
             model_kwargs["torch_dtype"] = self.torch_dtype
             model_kwargs["device_map"] = self.device if self.device == "cuda" else "cpu"
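One subtlety in the fallback above: BitsAndBytesConfig is defined inside transformers itself, so `from transformers import BitsAndBytesConfig` normally succeeds even when the bitsandbytes package is absent, and the ImportError guard may never trigger. Likewise, `bnb_8bit_compute_dtype` and `bnb_8bit_use_double_quant` are not documented BitsAndBytesConfig fields (the documented variants are the bnb_4bit_* ones), so recent transformers releases should ignore them with an "Unused kwargs" warning rather than fail. The sketch below is a minimal standalone variant of the same pattern that probes bitsandbytes directly; the build_model_kwargs helper is hypothetical, not part of this repo.

# Minimal sketch of the commit's graceful-fallback pattern, probing the
# optional dependency itself rather than the transformers config class.
import torch

try:
    import bitsandbytes  # noqa: F401 -- probe bitsandbytes directly
    from transformers import BitsAndBytesConfig
    HAS_BITSANDBYTES = True
except ImportError:
    HAS_BITSANDBYTES = False


def build_model_kwargs(use_8bit: bool, device: str = "cpu") -> dict:
    """Kwargs for AutoModelForCausalLM.from_pretrained() (hypothetical helper)."""
    if use_8bit and HAS_BITSANDBYTES:
        # load_in_8bit=True is the documented switch; unrecognized bnb_8bit_*
        # kwargs would be ignored with a warning by recent transformers releases.
        return {
            "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
            "device_map": device,
        }
    # Fallback: full precision on CPU, fp16 on GPU
    return {
        "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
        "device_map": device,
    }


print(build_model_kwargs(use_8bit=True))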
requirements.txt CHANGED
@@ -1,5 +1,8 @@
-fastapi
-uvicorn[standard]
-transformers[torch]
-accelerate
-huggingface_hub
+fastapi==0.104.1
+uvicorn[standard]==0.24.0
+transformers==4.36.2
+accelerate==0.25.0
+huggingface_hub==0.19.4
+bitsandbytes==0.49.0
+torch>=2.1.0
+pydantic==2.5.0
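The move from unpinned to pinned versions trades flexibility for reproducibility, but install drift stays silent until runtime. After `pip install -r requirements.txt`, a quick check that every pin actually resolved can catch a missing or mismatched package early. A minimal sketch (hypothetical check_pins.py, not part of this commit):

# check_pins.py -- report installed versions of the packages pinned above
from importlib.metadata import version, PackageNotFoundError

PINNED = ["fastapi", "uvicorn", "transformers", "accelerate",
          "huggingface_hub", "bitsandbytes", "torch", "pydantic"]

for pkg in PINNED:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: NOT INSTALLED")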