Spaces:
Runtime error
Runtime error
Update compute_loss.py
Browse files — enabling new syntax structures (e.g., 'an apple is blue')
- compute_loss.py +71 -3
compute_loss.py
CHANGED
|
@@ -142,8 +142,8 @@ def align_wordpieces_indices(
|
|
| 142 |
return wp_indices
|
| 143 |
|
| 144 |
|
| 145 |
-
def extract_attribution_indices(
|
| 146 |
-
doc = parser(prompt)
|
| 147 |
subtrees = []
|
| 148 |
modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
|
| 149 |
|
|
@@ -167,6 +167,74 @@ def extract_attribution_indices(prompt, parser):
|
|
| 167 |
subtrees.append(subtree)
|
| 168 |
return subtrees
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
def calculate_negative_loss(
|
| 172 |
attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
|
|
@@ -187,7 +255,7 @@ def calculate_negative_loss(
|
|
| 187 |
return negative_loss
|
| 188 |
|
| 189 |
def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
|
| 190 |
-
"""Utility function to list the indices of the tokens you wish to
|
| 191 |
ids = tokenizer(prompt).input_ids
|
| 192 |
indices = {
|
| 193 |
i: tok
|
|
|
|
| 142 |
return wp_indices
|
| 143 |
|
| 144 |
|
| 145 |
+
def extract_attribution_indices(doc):
|
| 146 |
+
# doc = parser(prompt)
|
| 147 |
subtrees = []
|
| 148 |
modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
|
| 149 |
|
|
|
|
| 167 |
subtrees.append(subtree)
|
| 168 |
return subtrees
|
| 169 |
|
| 170 |
+
def extract_attribution_indices_with_verbs(doc):
    """Group each noun with modifiers reachable through an intervening verb.

    Covers phrasings where a verb or auxiliary stands between a noun and
    the word describing it, e.g. "a dog that is red": the auxiliary links
    'dog' to 'red' but is itself left out of the group.

    Args:
        doc: Iterable of parsed tokens exposing ``pos_``, ``dep_`` and
            ``children`` (presumably a spaCy ``Doc`` -- confirm with caller).

    Returns:
        A list of groups. Each group lists the collected modifier tokens
        followed by the noun they belong to, which is always last. Nouns
        with no collected modifier produce no group.
    """
    verbal_tags = ("AUX", "VERB")
    noun_tags = ("NOUN", "PROPN")
    linking_deps = (
        "amod", "nmod", "compound", "npadvmod", "advmod", "acomp", "relcl",
    )

    groups = []
    for head in doc:
        # Only nouns that are not themselves acting as a modifier start a group.
        is_nominal = head.pos_ in noun_tags
        if not is_nominal or head.dep_ in linking_deps:
            continue

        attached = [kid for kid in head.children if kid.dep_ in linking_deps]
        # Verbs/auxiliaries are bridges only: they are not collected, but
        # their children are still explored below.
        members = [kid for kid in attached if kid.pos_ not in verbal_tags]
        pending = [grandkid for kid in attached for grandkid in kid.children]

        while pending:
            current = pending.pop()
            if current.dep_ != "conj" and current.dep_ not in linking_deps:
                continue
            if current.pos_ not in verbal_tags:
                members.append(current)
            pending.extend(current.children)

        if members:
            members.append(head)
            groups.append(members)
    return groups
|
| 200 |
+
|
| 201 |
+
def extract_attribution_indices_with_verb_root(doc):
    """Group nouns and modifiers that both hang off the same auxiliary.

    Targets prompts such as "an apple is blue", where the noun and its
    modifier are expected to be siblings under the auxiliary rather than
    parent and child (assumes the parser attaches both to 'is' -- confirm
    against the parser in use).

    Args:
        doc: Iterable of parsed tokens exposing ``pos_``, ``dep_`` and
            ``children`` (presumably a spaCy ``Doc`` -- confirm with caller).

    Returns:
        A list of groups, one per qualifying auxiliary. A group holds the
        auxiliary's direct noun/modifier children, then any modifier or
        ``conj`` descendants of those children. The auxiliary itself is
        never included.
    """
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
    subtrees = []
    for w in doc:
        # Only an auxiliary that is not itself a modifier can act as the root.
        if w.pos_ != "AUX" or w.dep_ in modifiers:
            continue

        subtree = []
        stack = []
        for child in w.children:
            if child.dep_ in modifiers or child.pos_ in ["NOUN", "PROPN"]:
                # Verbal children are bridges: skip them, keep exploring below.
                if child.pos_ not in ["AUX", "VERB"]:
                    subtree.append(child)
                stack.extend(child.children)

        # Need at least two collected children to have something to bind.
        # NOTE(review): this checks the count only, not that one child is a
        # noun and the other a modifier -- confirm that is intended.
        if len(subtree) < 2:
            continue

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                # We don't want 'is' in the loss, only what hangs below it.
                if node.pos_ != "AUX":
                    subtree.append(node)
                stack.extend(node.children)

        # w is always an AUX here, so it is never appended to its own group.
        subtrees.append(subtree)
    return subtrees
|
| 238 |
|
| 239 |
def calculate_negative_loss(
|
| 240 |
attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
|
|
|
|
| 255 |
return negative_loss
|
| 256 |
|
| 257 |
def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
|
| 258 |
+
"""Utility function to list the indices of the tokens you wish to alter"""
|
| 259 |
ids = tokenizer(prompt).input_ids
|
| 260 |
indices = {
|
| 261 |
i: tok
|