dlouapre (HF Staff) committed
Commit c5681ae · 1 Parent(s): 4dbfbc3

Creating the steering demo

Files changed (5):
  1. app.py +174 -53
  2. demo.yaml +35 -0
  3. requirements.txt +485 -0
  4. steering.py +286 -0
  5. steering_vectors.pt +3 -0
app.py CHANGED
@@ -1,70 +1,191 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
     """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
     """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

-    messages = [{"role": "system", "content": system_message}]

-    messages.extend(history)

-    messages.append({"role": "user", "content": message})

-    response = ""

-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
     ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content

-        response += token
-        yield response


-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
         ),
-    ],
-)

-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()


 if __name__ == "__main__":
-    demo.launch()
+"""
+Gradio demo for steered LLM generation using SAE features.
+Supports real-time streaming generation with HuggingFace Transformers.
+
+IMPORTANT: Before running this app, you must extract steering vectors:
+    python extract_steering_vectors.py
+
+This creates steering_vectors.pt which is much faster to load than
+downloading full SAE files from HuggingFace Hub.
+
+For HuggingFace Spaces ZeroGPU deployment, the @spaces.GPU decorator
+ensures efficient GPU allocation only during inference.
+"""
 import gradio as gr
+import torch
+import yaml
+import os
+
+# ZeroGPU support for HuggingFace Spaces
+try:
+    import spaces
+    SPACES_AVAILABLE = True
+except ImportError:
+    SPACES_AVAILABLE = False
+    # Create a dummy decorator for local development
+    def spaces_gpu_decorator(func):
+        return func
+    spaces = type('spaces', (), {'GPU': spaces_gpu_decorator})()
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from steering import load_saes_from_file, stream_steered_answer_hf
+
+# Global variables
+model = None
+tokenizer = None
+steering_components = None
+cfg = None
+
+
+def initialize_model():
+    """
+    Load model, SAEs, and configuration on startup.
+
+    For ZeroGPU: Model is loaded with device_map="auto" and will be automatically
+    moved to GPU when @spaces.GPU decorated functions are called. Steering vectors
+    are loaded on CPU initially and moved to GPU during inference.
     """
+    global model, tokenizer, steering_components, cfg
+
+    # Get HuggingFace token for gated models (if needed)
+    hf_token = os.getenv("HF_TOKEN", None)
+    if hf_token:
+        print("Using HF_TOKEN from environment")
+
+    print("Loading configuration...")
+    with open("demo.yaml", "r") as f:
+        cfg = yaml.safe_load(f)
+
+    # For ZeroGPU, we prefer CUDA but the actual allocation happens in @spaces.GPU functions
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    print(f"Loading model: {cfg['llm_name']}...")
+    print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}")
+
+    model = AutoModelForCausalLM.from_pretrained(
+        cfg['llm_name'],
+        device_map="auto",
+        dtype=torch.float16 if device == "cuda" else torch.float32,
+        token=hf_token
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)
+
+    print("Loading SAE steering components...")
+    # Use pre-extracted steering vectors for faster loading
+    # For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference
+    steering_vectors_file = "steering_vectors.pt"
+    load_device = "cpu" if SPACES_AVAILABLE else device
+    steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)
+    for i in range(len(steering_components)):
+        steering_components[i]['vector'] /= steering_components[i]['vector'].norm()
+
+    print("Model initialized successfully!")
+    return model, tokenizer, steering_components, cfg
+
+
+@spaces.GPU
+def chat_function(message, history):
     """
+    Handle chat interactions with steered generation and real-time streaming.

+    Decorated with @spaces.GPU to allocate GPU only during inference on HuggingFace Spaces.

+    Args:
+        message: User's input message
+        history: List of previous [user_msg, bot_msg] pairs from Gradio

+    Yields:
+        Partial text updates as tokens are generated
+    """
+    global model, tokenizer, steering_components, cfg
+
+    # Convert Gradio history format to chat format
+    chat = []
+    for user_msg, bot_msg in history:
+        chat.append({"role": "user", "content": user_msg})
+        if bot_msg is not None:
+            chat.append({"role": "assistant", "content": bot_msg})

+    # Add current message
+    chat.append({"role": "user", "content": message})

+    # Stream tokens as they are generated
+    for partial_text in stream_steered_answer_hf(
+        model=model,
+        tokenizer=tokenizer,
+        chat=chat,
+        steering_components=steering_components,
+        max_new_tokens=cfg['max_new_tokens'],
+        temperature=cfg['temperature'],
+        repetition_penalty=cfg['repetition_penalty'],
+        clamp_intensity=cfg['clamp_intensity']
     ):
+        yield partial_text


+def create_demo():
+    """Create and configure the Gradio interface."""

+    # Custom CSS for better appearance
+    custom_css = """
+    .gradio-container {
+        font-family: 'Arial', sans-serif;
+    }
+    #chatbot {
+        height: 600px;
+    }
+    """
+
+    # Create the interface
+    demo = gr.ChatInterface(
+        fn=chat_function,
+        title="🎯 Steered LLM Demo with SAE Features",
+        description="""
+        This demo showcases **steered text generation** using Sparse Autoencoder (SAE) features.
+
+        The model (Llama 3.1 8B Instruct) has its activations modified using vectors extracted from SAEs,
+        resulting in controlled behavior changes during generation.
+
+        **Features:**
+        - Real-time streaming: tokens appear as they're generated ⚡
+        - Multi-turn conversations with full history
+        - SAE-based activation steering across multiple layers
+
+        Start chatting below!
+        """,
+        examples=[
+            "Explain how neural networks work.",
+            "Tell me a creative story about a robot.",
+            "What are the applications of AI in healthcare?"
+        ],
+        cache_examples=False,
+        theme=gr.themes.Soft(),
+        css=custom_css,
+        chatbot=gr.Chatbot(
+            elem_id="chatbot",
+            bubble_full_width=False,
+            show_copy_button=True
         ),
+    )

+    return demo


 if __name__ == "__main__":
+    print("=" * 60)
+    print("Steered LLM Demo - Initializing")
+    print("=" * 60)
+
+    initialize_model()
+
+    print("\n" + "=" * 60)
+    print("Launching Gradio interface...")
+    print("=" * 60 + "\n")
+
+    demo = create_demo()
+    demo.launch(
+        share=False,  # Set to True for public link
+        server_name="0.0.0.0",  # Allow external access
+        server_port=7860  # Default HF Spaces port
+    )
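Note that the extract_steering_vectors.py script referenced in the app.py docstring is not part of this commit. As a hypothetical sketch of what such a script might do, pieced together from load_saes() and from the (layer, feature) -> vector dictionary that load_saes_from_file() expects, it could look roughly like the following; the file name, the "decoder.weight" key, and the unit-normalization step are assumptions, not the author's actual script.

# extract_steering_vectors.py -- hypothetical sketch, not included in this commit
import torch
import yaml
from huggingface_hub import hf_hub_download

with open("demo.yaml", "r") as f:
    cfg = yaml.safe_load(f)

vectors = {}
for feature in cfg["features"]:
    layer_idx, feature_idx = feature[0], feature[1]
    # Download the SAE checkpoint for this layer and take the decoder column
    sae_filename = cfg["sae_filename_prefix"] + f"{layer_idx}" + cfg["sae_filename_suffix"]
    file_path = hf_hub_download(repo_id=cfg["sae_path"], filename=sae_filename, cache_dir="./downloads")
    sae = torch.load(file_path, map_location="cpu")
    vec = sae["decoder.weight"][:, feature_idx]           # decoder column = feature direction (assumed key)
    vectors[(layer_idx, feature_idx)] = vec / vec.norm()  # stored unit-normalized, as load_saes_from_file expects

torch.save(vectors, "steering_vectors.pt")
print(f"Saved {len(vectors)} steering vector(s) to steering_vectors.pt")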
demo.yaml ADDED
@@ -0,0 +1,35 @@
+# Model configuration
+llm_name: "meta-llama/Llama-3.1-8B-Instruct"
+sae_path: "andyrdt/saes-llama-3.1-8b-instruct"
+sae_filename_prefix: "resid_post_layer_"
+sae_filename_suffix: "/trainer_1/ae.pt"
+
+reduced_strengths: false
+features:
+  # - [3, 4774]
+  # - [3, 13935]
+  # - [3, 94572]
+  # - [3, 88169]
+  # - [3, 60537]
+  # - [3, 121375]
+  # - [7, 56243]
+  # - [7, 65190]
+  # - [7, 70732]
+  - [11, 74457, 1.03]
+  - [11, 18894, 1.42]
+  - [11, 61463, 1.77]
+  - [15, 21576, 4.85]
+  - [19, 93, 6.69]
+  - [23, 111898, 10.3]
+  - [23, 40788, 3.24]
+  - [23, 21334, 1.38]
+  # - [27, 52459]
+  # - [27, 86068]
+
+# Generation parameters
+temperature: 0.5
+seed: 16
+max_new_tokens: 256
+repetition_penalty: 1.1
+steer_prompt: true
+clamp_intensity: true
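For readers of the config: each active features entry is a [layer index, SAE feature index, steering strength] triple, and the commented-out entries carry no strength yet. The snippet below mirrors how steering.load_saes_from_file() (added later in this commit) unpacks one entry; the example entry is taken from the list above, purely for illustration.

# How one demo.yaml feature entry is interpreted (mirrors load_saes_from_file)
feature = [11, 74457, 1.03]              # [layer index, SAE feature index, steering strength]
layer_idx, feature_idx = feature[0], feature[1]
strength = feature[2] if len(feature) > 2 else 0.0
# With reduced_strengths: true, the strength would additionally be scaled by layer_idx.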
requirements.txt ADDED
@@ -0,0 +1,485 @@
+# This file was autogenerated by uv via the following command:
+#    uv pip compile pyproject.toml -o requirements.txt
+accelerate==1.11.0
+    # via
+    #   eiffel-demo (pyproject.toml)
+    #   nnsight
+    #   transformer-lens
+aiofiles==24.1.0
+    # via gradio
+aiohappyeyeballs==2.6.1
+    # via aiohttp
+aiohttp==3.13.2
+    # via fsspec
+aiosignal==1.4.0
+    # via aiohttp
+annotated-doc==0.0.3
+    # via fastapi
+annotated-types==0.7.0
+    # via pydantic
+anyio==4.11.0
+    # via
+    #   gradio
+    #   httpx
+    #   starlette
+astor==0.8.1
+    # via nnsight
+asttokens==3.0.0
+    # via stack-data
+attrs==25.4.0
+    # via aiohttp
+babe==0.0.7
+    # via sae-lens
+beartype==0.14.1
+    # via transformer-lens
+better-abc==0.0.3
+    # via transformer-lens
+bidict==0.23.1
+    # via python-socketio
+brotli==1.1.0
+    # via gradio
+certifi==2025.10.5
+    # via
+    #   httpcore
+    #   httpx
+    #   requests
+    #   sentry-sdk
+charset-normalizer==3.4.4
+    # via requests
+click==8.3.0
+    # via
+    #   nltk
+    #   typer
+    #   uvicorn
+    #   wandb
+cloudpickle==3.1.2
+    # via nnsight
+config2py==0.1.42
+    # via py2store
+datasets==4.4.0
+    # via
+    #   sae-lens
+    #   transformer-lens
+decorator==5.2.1
+    # via ipython
+dill==0.4.0
+    # via
+    #   datasets
+    #   multiprocess
+docstring-parser==0.17.0
+    # via simple-parsing
+dol==0.3.31
+    # via
+    #   config2py
+    #   graze
+    #   py2store
+einops==0.8.1
+    # via transformer-lens
+executing==2.2.1
+    # via stack-data
+fancy-einsum==0.0.3
+    # via transformer-lens
+fastapi==0.121.0
+    # via gradio
+ffmpy==0.6.4
+    # via gradio
+filelock==3.20.0
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   torch
+    #   transformers
+frozenlist==1.8.0
+    # via
+    #   aiohttp
+    #   aiosignal
+fsspec==2025.10.0
+    # via
+    #   datasets
+    #   gradio-client
+    #   huggingface-hub
+    #   torch
+gitdb==4.0.12
+    # via gitpython
+gitpython==3.1.45
+    # via wandb
+gradio==5.49.1
+    # via eiffel-demo (pyproject.toml)
+gradio-client==1.13.3
+    # via gradio
+graze==0.1.39
+    # via babe
+groovy==0.1.2
+    # via gradio
+h11==0.16.0
+    # via
+    #   httpcore
+    #   uvicorn
+    #   wsproto
+hf-transfer==0.1.9
+    # via eiffel-demo (pyproject.toml)
+hf-xet==1.2.0
+    # via huggingface-hub
+httpcore==1.0.9
+    # via httpx
+httpx==0.28.1
+    # via
+    #   datasets
+    #   gradio
+    #   gradio-client
+    #   safehttpx
+huggingface-hub==0.36.0
+    # via
+    #   accelerate
+    #   datasets
+    #   gradio
+    #   gradio-client
+    #   tokenizers
+    #   transformers
+i2==0.1.58
+    # via config2py
+idna==3.11
+    # via
+    #   anyio
+    #   httpx
+    #   requests
+    #   yarl
+importlib-resources==6.5.2
+    # via py2store
+ipython==9.6.0
+    # via nnsight
+ipython-pygments-lexers==1.1.1
+    # via ipython
+jaxtyping==0.3.3
+    # via transformer-lens
+jedi==0.19.2
+    # via ipython
+jinja2==3.1.6
+    # via
+    #   gradio
+    #   torch
+joblib==1.5.2
+    # via nltk
+markdown-it-py==4.0.0
+    # via rich
+markupsafe==3.0.3
+    # via
+    #   gradio
+    #   jinja2
+matplotlib-inline==0.2.1
+    # via ipython
+mdurl==0.1.2
+    # via markdown-it-py
+mpmath==1.3.0
+    # via sympy
+multidict==6.7.0
+    # via
+    #   aiohttp
+    #   yarl
+multiprocess==0.70.18
+    # via datasets
+narwhals==2.10.1
+    # via plotly
+networkx==3.5
+    # via torch
+nltk==3.9.2
+    # via sae-lens
+nnsight==0.5.10
+    # via eiffel-demo (pyproject.toml)
+numpy==1.26.4
+    # via
+    #   accelerate
+    #   datasets
+    #   gradio
+    #   pandas
+    #   patsy
+    #   plotly-express
+    #   scipy
+    #   statsmodels
+    #   transformer-lens
+    #   transformers
+nvidia-cublas-cu12==12.8.4.1
+    # via
+    #   nvidia-cudnn-cu12
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cuda-cupti-cu12==12.8.90
+    # via torch
+nvidia-cuda-nvrtc-cu12==12.8.93
+    # via torch
+nvidia-cuda-runtime-cu12==12.8.90
+    # via torch
+nvidia-cudnn-cu12==9.10.2.21
+    # via torch
+nvidia-cufft-cu12==11.3.3.83
+    # via torch
+nvidia-cufile-cu12==1.13.1.3
+    # via torch
+nvidia-curand-cu12==10.3.9.90
+    # via torch
+nvidia-cusolver-cu12==11.7.3.90
+    # via torch
+nvidia-cusparse-cu12==12.5.8.93
+    # via
+    #   nvidia-cusolver-cu12
+    #   torch
+nvidia-cusparselt-cu12==0.7.1
+    # via torch
+nvidia-nccl-cu12==2.27.5
+    # via torch
+nvidia-nvjitlink-cu12==12.8.93
+    # via
+    #   nvidia-cufft-cu12
+    #   nvidia-cusolver-cu12
+    #   nvidia-cusparse-cu12
+    #   torch
+nvidia-nvshmem-cu12==3.3.20
+    # via torch
+nvidia-nvtx-cu12==12.8.90
+    # via torch
+orjson==3.11.4
+    # via gradio
+packaging==25.0
+    # via
+    #   accelerate
+    #   datasets
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   plotly
+    #   statsmodels
+    #   transformers
+    #   wandb
+pandas==2.3.3
+    # via
+    #   babe
+    #   datasets
+    #   gradio
+    #   plotly-express
+    #   statsmodels
+    #   transformer-lens
+parso==0.8.5
+    # via jedi
+patsy==1.0.2
+    # via
+    #   plotly-express
+    #   statsmodels
+pexpect==4.9.0
+    # via ipython
+pillow==11.3.0
+    # via gradio
+platformdirs==4.5.0
+    # via wandb
+plotly==6.3.1
+    # via
+    #   plotly-express
+    #   sae-lens
+plotly-express==0.4.1
+    # via sae-lens
+prompt-toolkit==3.0.52
+    # via ipython
+propcache==0.4.1
+    # via
+    #   aiohttp
+    #   yarl
+protobuf==6.33.0
+    # via wandb
+psutil==7.1.3
+    # via accelerate
+ptyprocess==0.7.0
+    # via pexpect
+pure-eval==0.2.3
+    # via stack-data
+py2store==0.1.22
+    # via babe
+pyarrow==22.0.0
+    # via datasets
+pydantic==2.11.10
+    # via
+    #   fastapi
+    #   gradio
+    #   nnsight
+    #   wandb
+pydantic-core==2.33.2
+    # via pydantic
+pydub==0.25.1
+    # via gradio
+pygments==2.19.2
+    # via
+    #   ipython
+    #   ipython-pygments-lexers
+    #   rich
+python-dateutil==2.9.0.post0
+    # via pandas
+python-dotenv==1.2.1
+    # via sae-lens
+python-engineio==4.12.3
+    # via python-socketio
+python-multipart==0.0.20
+    # via gradio
+python-socketio==5.14.3
+    # via nnsight
+pytz==2025.2
+    # via pandas
+pyyaml==6.0.3
+    # via
+    #   eiffel-demo (pyproject.toml)
+    #   accelerate
+    #   datasets
+    #   gradio
+    #   huggingface-hub
+    #   sae-lens
+    #   transformers
+    #   wandb
+regex==2025.11.3
+    # via
+    #   nltk
+    #   transformers
+requests==2.32.5
+    # via
+    #   datasets
+    #   graze
+    #   huggingface-hub
+    #   python-socketio
+    #   transformers
+    #   wandb
+rich==14.2.0
+    # via
+    #   nnsight
+    #   transformer-lens
+    #   typer
+ruff==0.14.3
+    # via gradio
+sae-lens==6.21.0
+    # via eiffel-demo (pyproject.toml)
+safehttpx==0.1.7
+    # via gradio
+safetensors==0.6.2
+    # via
+    #   accelerate
+    #   sae-lens
+    #   transformers
+scipy==1.16.3
+    # via
+    #   plotly-express
+    #   statsmodels
+semantic-version==2.10.0
+    # via gradio
+sentencepiece==0.2.1
+    # via transformer-lens
+sentry-sdk==2.43.0
+    # via wandb
+shellingham==1.5.4
+    # via typer
+simple-parsing==0.1.7
+    # via sae-lens
+simple-websocket==1.1.0
+    # via python-engineio
+six==1.17.0
+    # via python-dateutil
+smmap==5.0.2
+    # via gitdb
+sniffio==1.3.1
+    # via anyio
+stack-data==0.6.3
+    # via ipython
+starlette==0.49.3
+    # via
+    #   fastapi
+    #   gradio
+statsmodels==0.14.5
+    # via plotly-express
+sympy==1.14.0
+    # via torch
+tenacity==9.1.2
+    # via sae-lens
+tokenizers==0.22.1
+    # via transformers
+toml==0.10.2
+    # via nnsight
+tomlkit==0.13.3
+    # via gradio
+torch==2.9.0
+    # via
+    #   eiffel-demo (pyproject.toml)
+    #   accelerate
+    #   nnsight
+    #   transformer-lens
+tqdm==4.67.1
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   nltk
+    #   transformer-lens
+    #   transformers
+traitlets==5.14.3
+    # via
+    #   ipython
+    #   matplotlib-inline
+transformer-lens==2.16.1
+    # via sae-lens
+transformers==4.57.1
+    # via
+    #   eiffel-demo (pyproject.toml)
+    #   nnsight
+    #   sae-lens
+    #   transformer-lens
+    #   transformers-stream-generator
+transformers-stream-generator==0.0.5
+    # via transformer-lens
+triton==3.5.0
+    # via torch
+typeguard==4.4.4
+    # via transformer-lens
+typer==0.20.0
+    # via gradio
+typing-extensions==4.15.0
+    # via
+    #   aiosignal
+    #   anyio
+    #   fastapi
+    #   gradio
+    #   gradio-client
+    #   huggingface-hub
+    #   ipython
+    #   pydantic
+    #   pydantic-core
+    #   sae-lens
+    #   simple-parsing
+    #   starlette
+    #   torch
+    #   transformer-lens
+    #   typeguard
+    #   typer
+    #   typing-inspection
+    #   wandb
+typing-inspection==0.4.2
+    # via pydantic
+tzdata==2025.2
+    # via pandas
+urllib3==2.5.0
+    # via
+    #   requests
+    #   sentry-sdk
+uvicorn==0.38.0
+    # via gradio
+wadler-lindig==0.1.7
+    # via jaxtyping
+wandb==0.22.3
+    # via transformer-lens
+wcwidth==0.2.14
+    # via prompt-toolkit
+websocket-client==1.9.0
+    # via python-socketio
+websockets==15.0.1
+    # via gradio-client
+wsproto==1.2.0
+    # via simple-websocket
+xxhash==3.6.0
+    # via datasets
+yarl==1.22.0
+    # via aiohttp
+
+# HuggingFace Spaces ZeroGPU support
+spaces==0.28.3
+    # via eiffel-demo (for ZeroGPU deployment)
steering.py ADDED
@@ -0,0 +1,286 @@
+import torch
+from nnsight import LanguageModel
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
+from huggingface_hub import hf_hub_download
+
+
+def load_saes(cfg, device):
+    """Load steering vectors from SAEs and prepare steering components."""
+    if not cfg['features'] or len(cfg['features']) == 0:
+        print("No features specified, returning empty steering components.")
+        return []
+
+    steering_components = []
+    cache_dir = "./downloads"
+    features = cfg['features']
+    reduced_strengths = cfg['reduced_strengths']
+
+    for i, feature in enumerate(features):
+        layer_idx, feature_idx = feature[0], feature[1]
+        strength = feature[2] if len(feature) > 2 else 0.0
+
+        # If the strengths in the config file were given in reduced form, scale them by layer index
+        if reduced_strengths:
+            strength *= layer_idx
+
+        # Display strength (avoid division by zero)
+        reduced_str = f"[{strength/layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
+        print(f"Loading feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")
+
+        sae_filename = cfg['sae_filename_prefix'] + f"{layer_idx}" + cfg['sae_filename_suffix']
+        file_path = hf_hub_download(repo_id=cfg['sae_path'], filename=sae_filename, cache_dir=cache_dir)
+        sae = torch.load(file_path, map_location="cpu")
+        vec = sae["decoder.weight"][:, feature_idx].to(device, non_blocking=True)
+
+        steering_components.append({
+            'layer': layer_idx,
+            'feature': feature_idx,
+            'strength': strength,
+            'vector': vec
+        })
+        del sae
+
+    return steering_components
+
+
+def load_saes_from_file(file_path, cfg, device):
+    """
+    Load pre-extracted steering vectors from a local file.
+
+    This is much faster than load_saes() since it doesn't download large SAE files.
+    The file should be created using extract_steering_vectors.py script.
+
+    Args:
+        file_path: Path to the .pt file containing steering vectors
+        cfg: Configuration dict with 'features' list
+        device: Device to load tensors on ('cuda' or 'cpu')
+
+    Returns:
+        List of steering component dicts with keys: 'layer', 'feature', 'strength', 'vector'
+    """
+    import os
+
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(
+            f"Steering vectors file not found: {file_path}\n"
+            f"Please run: python extract_steering_vectors.py"
+        )
+
+    print(f"Loading pre-extracted steering vectors from {file_path}...")
+
+    # Load the dictionary of vectors
+    steering_vectors_dict = torch.load(file_path, map_location="cpu")
+
+    if not cfg['features'] or len(cfg['features']) == 0:
+        print("No features specified in config.")
+        return []
+
+    steering_components = []
+    features = cfg['features']
+    reduced_strengths = cfg.get('reduced_strengths', False)
+
+    for i, feature in enumerate(features):
+        layer_idx, feature_idx = feature[0], feature[1]
+        strength = feature[2] if len(feature) > 2 else 0.0
+
+        if reduced_strengths:
+            strength *= layer_idx
+
+        # Look up the pre-extracted vector
+        key = (layer_idx, feature_idx)
+        if key not in steering_vectors_dict:
+            raise KeyError(
+                f"Vector for layer {layer_idx}, feature {feature_idx} not found in {file_path}.\n"
+                f"Please re-run: python extract_steering_vectors.py"
+            )
+
+        vec = steering_vectors_dict[key].to(device, non_blocking=True)
+
+        # Display
+        reduced_str = f"[{strength/layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
+        print(f"Loaded feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")
+
+        steering_components.append({
+            'layer': layer_idx,
+            'feature': feature_idx,
+            'strength': strength,
+            'vector': vec  # Already normalized in the file
+        })
+
+    print(f"Loaded {len(steering_components)} steering vector(s) from local file")
+    return steering_components
+
+
+def generate_steered_answer(model: LanguageModel,
+                            chat,
+                            steering_components,
+                            max_new_tokens=128,
+                            temperature=0.0,
+                            repetition_penalty=1.0,
+                            clamp_intensity=False):
+    """
+    Generates an answer from the model given a chat history, applying steering components.
+    Expects steering_components to be a list of dicts with keys:
+        'layer': int, layer index to apply steering
+        'strength': float, steering intensity
+        'vector': torch.Tensor, steering vector
+    """
+    input_ids = model.tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
+    with model.generate(max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
+                        do_sample=temperature > 0.0, temperature=temperature,
+                        pad_token_id=model.tokenizer.eos_token_id) as tracer:
+        with tracer.invoke(input_ids):
+            with tracer.all():
+                for sc in steering_components:
+                    layer, strength, vector = sc["layer"], sc["strength"], sc["vector"]
+
+                    # Ensure vector matches model dtype and device
+                    layer_output = model.model.layers[layer].output
+                    vector = vector.to(dtype=layer_output.dtype, device=layer_output.device)
+
+                    length = layer_output.shape[1]
+                    amount = (strength * vector).unsqueeze(0).expand(length, -1).unsqueeze(0).clone()
+                    if clamp_intensity:
+                        projection = (layer_output @ vector).unsqueeze(-1) @ (vector.unsqueeze(0))
+                        amount -= projection
+
+                    layer_output += amount
+        with tracer.invoke():
+            trace = model.generator.output.save()
+
+    answer = model.tokenizer.decode(trace[0][len(input_ids):], skip_special_tokens=True)
+    output = {'input_ids': input_ids, 'trace': trace, 'answer': answer}
+    return output
+
+
+
+def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
+    """
+    Create a forward hook for a specific layer that applies steering.
+
+    Args:
+        layer_idx: Which layer this hook is for
+        steering_components: List of steering components (all layers)
+        clamp_intensity: Whether to clamp steering intensity
+
+    Returns:
+        Forward hook function
+    """
+    layer_components = [sc for sc in steering_components if sc['layer'] == layer_idx]
+
+    if not layer_components:
+        return None
+
+    def hook(module, input, output):
+        """Forward hook that modifies the output hidden states."""
+        # Handle different output formats (tuple vs tensor)
+        if isinstance(output, tuple):
+            hidden_states = output[0]
+            rest_of_output = output[1:]
+        else:
+            hidden_states = output
+            rest_of_output = None
+
+        # Handle different shapes during generation
+        original_shape = hidden_states.shape
+        if len(original_shape) == 2:
+            # During generation: [batch, hidden_dim] -> add seq_len dimension
+            hidden_states = hidden_states.unsqueeze(1)  # [batch, 1, hidden_dim]
+
+        for sc in layer_components:
+            strength = sc['strength']
+            vector = sc['vector']  # Already normalized
+
+            # Ensure vector matches hidden_states dtype and device
+            vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)
+
+            # Match nnsight's expansion pattern exactly
+            seq_len = hidden_states.shape[1]
+            amount = (strength * vector).unsqueeze(0).expand(seq_len, -1).unsqueeze(0)  # [1, seq_len, hidden_dim]
+
+            if clamp_intensity:
+                # Remove existing projection (prevents over-steering)
+                projection_scalars = torch.einsum('bsh,h->bs', hidden_states, vector).unsqueeze(-1)
+                projection_vectors = projection_scalars * vector.view(1, 1, -1)
+                amount = amount - projection_vectors
+
+            hidden_states = hidden_states + amount
+
+        # Restore original shape if we added a dimension
+        if len(original_shape) == 2:
+            hidden_states = hidden_states.squeeze(1)  # [batch, hidden_dim]
+
+        # Return in the same format as input
+        if rest_of_output is not None:
+            return (hidden_states,) + rest_of_output
+        else:
+            return hidden_states
+
+    return hook
+
+
+def stream_steered_answer_hf(model: AutoModelForCausalLM,
+                             tokenizer: AutoTokenizer,
+                             chat,
+                             steering_components,
+                             max_new_tokens=128,
+                             temperature=0.0,
+                             repetition_penalty=1.0,
+                             clamp_intensity=False,
+                             stream=True):
+    """
+    Generate steered answer using pure HuggingFace Transformers with streaming.
+
+    Args:
+        model: HuggingFace transformers model
+        tokenizer: Tokenizer instance
+        chat: Chat history in OpenAI format
+        steering_components: List of dicts with 'layer', 'strength', 'vector'
+        max_new_tokens: Maximum tokens to generate
+        temperature: Sampling temperature (0 = greedy)
+        repetition_penalty: Repetition penalty
+        clamp_intensity: Whether to clamp steering intensity
+
+    Yields:
+        Partial text as tokens are generated
+
+    """
+
+    input_ids_list = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
+    input_ids = torch.tensor([input_ids_list]).to(model.device)
+
+    # Register steering hooks
+    hook_handles = []
+    layers_to_steer = set(sc['layer'] for sc in steering_components)
+
+    for layer_idx in layers_to_steer:
+        hook_fn = create_steering_hook(layer_idx, steering_components, clamp_intensity)
+        if hook_fn:
+            layer_module = model.model.layers[layer_idx]
+            handle = layer_module.register_forward_hook(hook_fn)
+            hook_handles.append(handle)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = {
+        "input_ids": input_ids,
+        "max_new_tokens": max_new_tokens,
+        "temperature": temperature if temperature > 0 else 1.0,
+        "do_sample": temperature > 0,
+        "repetition_penalty": repetition_penalty,
+        "streamer": streamer,
+        "pad_token_id": tokenizer.eos_token_id,
+    }
+
+    thread = Thread(target=lambda: model.generate(**generation_kwargs))
+    thread.start()
+
+    generated_text = ""
+    for token_text in streamer:
+        generated_text += token_text
+        yield generated_text
+
+    thread.join()
+
+    for handle in hook_handles:
+        handle.remove()
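One note on the clamp_intensity option used by both generate_steered_answer() and create_steering_hook(): since the steering vectors are unit-normalized before use, the update h + strength*v - (h·v)v replaces the hidden state's component along v with exactly strength instead of adding on top of whatever is already there. A minimal standalone check of that arithmetic (illustration only, not part of the app):

# Toy check of the clamp_intensity update (h' · v == strength for unit v)
import torch

torch.manual_seed(0)
v = torch.randn(8); v = v / v.norm()   # unit steering vector
h = torch.randn(1, 3, 8)               # [batch, seq_len, hidden_dim]
strength = 4.0

# Same operations as in create_steering_hook's clamp_intensity branch
proj = torch.einsum('bsh,h->bs', h, v).unsqueeze(-1) * v.view(1, 1, -1)
h_steered = h + strength * v - proj

print(torch.einsum('bsh,h->bs', h_steered, v))  # every entry is ~4.0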
steering_vectors.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba54a67bef9880b37df42668de7b5561e886bb3be591535409740d56f445f287
+size 134539