Spaces: Running on Zero

Removing nnsight imports

- steering.py  +0 -83

steering.py CHANGED
@@ -1,47 +1,6 @@
 import torch
-from nnsight import LanguageModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
-from huggingface_hub import hf_hub_download
-
-
-def load_saes(cfg, device):
-    """Load steering vectors from SAEs and prepare steering components."""
-    if not cfg['features'] or len(cfg['features']) == 0:
-        print("No features specified, returning empty steering components.")
-        return []
-
-    steering_components = []
-    cache_dir = "./downloads"
-    features = cfg['features']
-    reduced_strengths = cfg['reduced_strengths']
-
-    for i, feature in enumerate(features):
-        layer_idx, feature_idx = feature[0], feature[1]
-        strength = feature[2] if len(feature) > 2 else 0.0
-
-        # If the strengths in the config file were given in reduced form, scale them by layer index
-        if reduced_strengths:
-            strength *= layer_idx
-
-        # Display strength (avoid division by zero)
-        reduced_str = f"[{strength/layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
-        print(f"Loading feature {layer_idx} {feature_idx} {strength:.2f} {reduced_str}")
-
-        sae_filename = cfg['sae_filename_prefix'] + f"{layer_idx}" + cfg['sae_filename_suffix']
-        file_path = hf_hub_download(repo_id=cfg['sae_path'], filename=sae_filename, cache_dir=cache_dir)
-        sae = torch.load(file_path, map_location="cpu")
-        vec = sae["decoder.weight"][:, feature_idx].to(device, non_blocking=True)
-
-        steering_components.append({
-            'layer': layer_idx,
-            'feature': feature_idx,
-            'strength': strength,
-            'vector': vec
-        })
-        del sae
-
-    return steering_components
 
 
 def load_saes_from_file(file_path, cfg, device):
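For reference, the removed load_saes pulled everything it needed from a single cfg dict. The keys below mirror the lookups in the deleted code exactly; the repo id, filename pattern, and feature triples are made-up placeholders, not values from this Space:

# Hypothetical cfg for the removed load_saes; keys match the deleted code above,
# values are illustrative placeholders only.
cfg = {
    'sae_path': 'some-org/some-sae-checkpoints',  # placeholder hf_hub_download repo id
    'sae_filename_prefix': 'sae_layer_',          # builds e.g. "sae_layer_12.pt"
    'sae_filename_suffix': '.pt',
    'reduced_strengths': True,   # strengths below are strength/layer_idx, rescaled at load time
    'features': [
        # [layer_idx, feature_idx, strength]; strength defaults to 0.0 when omitted
        [12, 4321, 0.5],
        [20, 987],
    ],
}
# steering_components = load_saes(cfg, device='cuda')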
@@ -112,48 +71,6 @@ def load_saes_from_file(file_path, cfg, device):
     return steering_components
 
 
-def generate_steered_answer(model: LanguageModel,
-                            chat,
-                            steering_components,
-                            max_new_tokens=128,
-                            temperature=0.0,
-                            repetition_penalty=1.0,
-                            clamp_intensity=False):
-    """
-    Generates an answer from the model given a chat history, applying steering components.
-    Expects steering_components to be a list of dicts with keys:
-        'layer': int, layer index to apply steering
-        'strength': float, steering intensity
-        'vector': torch.Tensor, steering vector
-    """
-    input_ids = model.tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)
-    with model.generate(max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty,
-                        do_sample=temperature > 0.0, temperature=temperature,
-                        pad_token_id=model.tokenizer.eos_token_id) as tracer:
-        with tracer.invoke(input_ids):
-            with tracer.all():
-                for sc in steering_components:
-                    layer, strength, vector = sc["layer"], sc["strength"], sc["vector"]
-
-                    # Ensure vector matches model dtype and device
-                    layer_output = model.model.layers[layer].output
-                    vector = vector.to(dtype=layer_output.dtype, device=layer_output.device)
-
-                    length = layer_output.shape[1]
-                    amount = (strength * vector).unsqueeze(0).expand(length, -1).unsqueeze(0).clone()
-                    if clamp_intensity:
-                        projection = (layer_output @ vector).unsqueeze(-1) @ (vector.unsqueeze(0))
-                        amount -= projection
-
-                    layer_output += amount
-        with tracer.invoke():
-            trace = model.generator.output.save()
-
-    answer = model.tokenizer.decode(trace[0][len(input_ids):], skip_special_tokens=True)
-    output = {'input_ids': input_ids, 'trace': trace, 'answer': answer}
-    return output
-
-
 
 def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
     """
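Two notes on the removed generate_steered_answer. First, the clamp_intensity branch subtracts the projection (h @ v) v before adding strength * v, so rather than stacking steering on top of whatever component the hidden state h already has along v, it pins that component at strength. Second, with the nnsight tracer gone, the same additive intervention can be expressed with plain PyTorch forward hooks, which is presumably what the surviving create_steering_hook provides. The sketch below is a minimal illustration under that assumption, not this Space's actual code: the model id is a placeholder, and model.model.layers assumes the same Llama-style module layout the removed code indexed.

import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "your-org/your-llama-style-model"  # placeholder model id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

def make_steering_hook(strength, vector):
    """Forward hook that adds strength * vector to a decoder layer's hidden states."""
    def hook(module, args, output):
        # Llama-style decoder layers return a tuple whose first element is the hidden states
        hidden = output[0] if isinstance(output, tuple) else output
        v = vector.to(dtype=hidden.dtype, device=hidden.device)
        hidden = hidden + strength * v  # same update as `layer_output += amount` above
        if isinstance(output, tuple):
            return (hidden,) + output[1:]
        return hidden
    return hook

# steering_components as produced by the surviving load_saes_from_file:
# [{'layer': int, 'strength': float, 'vector': torch.Tensor}, ...]
steering_components = []  # fill in with real components
handles = [model.model.layers[sc['layer']].register_forward_hook(
               make_steering_hook(sc['strength'], sc['vector']))
           for sc in steering_components]

chat = [{"role": "user", "content": "Tell me about bridges."}]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True,
                                          return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate,
                kwargs=dict(input_ids=input_ids, max_new_tokens=128, streamer=streamer,
                            pad_token_id=tokenizer.eos_token_id))
thread.start()
for text in streamer:
    print(text, end="", flush=True)
thread.join()

for h in handles:  # always detach hooks when done steering
    h.remove()

The retained TextIteratorStreamer and Thread imports fit this pattern: generation runs on a worker thread while tokens are consumed from the streamer on the main thread.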