dlouapre HF Staff committed on
Commit 3a0c265 · 1 Parent(s): 2a3cabe

Removing nnsight imports

Files changed (1)
  1. steering.py +0 -83
steering.py CHANGED
@@ -1,47 +1,6 @@
  import torch
- from nnsight import LanguageModel
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  from threading import Thread
- from huggingface_hub import hf_hub_download
-
-
- def load_saes(cfg, device):
-     """Load steering vectors from SAEs and prepare steering components."""
-     if not cfg['features'] or len(cfg['features']) == 0:
-         print("No features specified, returning empty steering components.")
-         return []
-
-     steering_components = []
-     cache_dir = "./downloads"
-     features = cfg['features']
-     reduced_strengths = cfg['reduced_strengths']
-
-     for i, feature in enumerate(features):
-         layer_idx, feature_idx = feature[0], feature[1]
-         strength = feature[2] if len(feature) > 2 else 0.0
-
-         # If the strengths in the config file were given in reduced form, scale them by layer index
-         if reduced_strengths:
-             strength *= layer_idx
-
-         # Display strength (avoid division by zero)
-         reduced_str = f"[{strength/layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
-         print(f"Loading feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")
-
-         sae_filename = cfg['sae_filename_prefix'] + f"{layer_idx}" + cfg['sae_filename_suffix']
-         file_path = hf_hub_download(repo_id=cfg['sae_path'], filename=sae_filename, cache_dir=cache_dir)
-         sae = torch.load(file_path, map_location="cpu")
-         vec = sae["decoder.weight"][:, feature_idx].to(device, non_blocking=True)
-
-         steering_components.append({
-             'layer': layer_idx,
-             'feature': feature_idx,
-             'strength': strength,
-             'vector': vec
-         })
-         del sae
-
-     return steering_components
 
 
  def load_saes_from_file(file_path, cfg, device):
@@ -112,48 +71,6 @@ def load_saes_from_file(file_path, cfg, device):
      return steering_components
 
 
- def generate_steered_answer(model: LanguageModel,
-                             chat,
-                             steering_components,
-                             max_new_tokens=128,
-                             temperature=0.0,
-                             repetition_penalty=1.0,
-                             clamp_intensity=False):
-     """
-     Generates an answer from the model given a chat history, applying steering components.
-     Expects steering_components to be a list of dicts with keys:
-         'layer': int, layer index to apply steering
-         'strength': float, steering intensity
-         'vector': torch.Tensor, steering vector
-     """
-     input_ids = model.tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
-     with model.generate(max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
-                         do_sample=temperature > 0.0, temperature=temperature,
-                         pad_token_id=model.tokenizer.eos_token_id) as tracer:
-         with tracer.invoke(input_ids):
-             with tracer.all():
-                 for sc in steering_components:
-                     layer, strength, vector = sc["layer"], sc["strength"], sc["vector"]
-
-                     # Ensure vector matches model dtype and device
-                     layer_output = model.model.layers[layer].output
-                     vector = vector.to(dtype=layer_output.dtype, device=layer_output.device)
-
-                     length = layer_output.shape[1]
-                     amount = (strength * vector).unsqueeze(0).expand(length, -1).unsqueeze(0).clone()
-                     if clamp_intensity:
-                         projection = (layer_output @ vector).unsqueeze(-1) @ (vector.unsqueeze(0))
-                         amount -= projection
-
-                     layer_output += amount
-         with tracer.invoke():
-             trace = model.generator.output.save()
-
-     answer = model.tokenizer.decode(trace[0][len(input_ids):], skip_special_tokens=True)
-     output = {'input_ids': input_ids, 'trace': trace, 'answer': answer}
-     return output
-
-
 
  def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
      """
 