import numpy as np
import torch
from typing import List, Tuple

from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer


def prepare_batch(
    seqs: List[str],
    tokenizer: CharLevelTokenizer,
    prepend_bos: bool = True,
    device: str = 'cuda:0',
) -> Tuple[torch.Tensor, List[int]]:
    """
    Takes in a list of sequences, tokenizes them, and puts them in a tensor batch.
    If the sequences have differing lengths, pads them up to the maximum sequence length.
    """
    seq_lengths = [len(seq) for seq in seqs]
    max_seq_length = max(seq_lengths)

    input_ids = []
    for seq in seqs:
        # Right-pad shorter sequences so every row in the batch has the same length.
        padding = [tokenizer.pad_id] * (max_seq_length - len(seq))
        input_ids.append(
            torch.tensor(
                # Optionally prepend the end-of-document token as a BOS marker.
                ([tokenizer.eod_id] * int(prepend_bos)) + tokenizer.tokenize(seq) + padding,
                dtype=torch.long,
            ).to(device).unsqueeze(0)
        )
    input_ids = torch.cat(input_ids, dim=0)

    return input_ids, seq_lengths
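

# A minimal usage sketch for `prepare_batch`. The stub below stands in for
# `CharLevelTokenizer` (it only mimics the `tokenize`/`pad_id`/`eod_id`
# interface used above, with assumed token id values) so the padding behavior
# can be checked without loading the real tokenizer; it is illustrative and
# not part of the original module.
def _check_prepare_batch_padding() -> None:
    class _StubTokenizer:
        pad_id = 1
        eod_id = 0

        def tokenize(self, seq: str) -> List[int]:
            # Byte-level tokenization as a stand-in for the real tokenizer.
            return list(seq.encode('utf-8'))

    batch, lengths = prepare_batch(['ACGT', 'AC'], _StubTokenizer(), device='cpu')
    # One prepended BOS token plus the longest sequence (4 characters).
    assert batch.shape == (2, 5)
    assert lengths == [4, 2]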


def logits_to_logprobs(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    trim_bos: bool = True,
) -> torch.Tensor:
    """
    Takes in a tensor of logits of dimension (batch, length, vocab).
    Computes the log-likelihoods using a softmax along the vocab dimension.
    Uses the `input_ids` to index into the log-likelihoods and returns the
    log-likelihood of the provided sequence at each position, with dimension
    (batch, length).
    """
    softmax_logprobs = torch.log_softmax(logits, dim=-1)
    if trim_bos:
        # Shift so that the logits at position i are scored against the token
        # at position i + 1; this also drops the prepended BOS token.
        softmax_logprobs = softmax_logprobs[:, :-1]
        input_ids = input_ids[:, 1:]
    assert softmax_logprobs.shape[1] == input_ids.shape[1]

    logprobs = torch.gather(
        softmax_logprobs,         # Gather the log-likelihoods...
        2,                        # along the vocab dimension...
        input_ids.unsqueeze(-1),  # using the token ids as indices.
    ).squeeze(-1)

    return logprobs
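

# A self-contained sanity check for `logits_to_logprobs` using random tensors
# in place of real model output. The helper is illustrative and not part of
# the original module.
def _check_logits_to_logprobs_shapes() -> None:
    batch, length, vocab = 2, 8, 512
    logits = torch.randn(batch, length, vocab)
    input_ids = torch.randint(0, vocab, (batch, length))
    logprobs = logits_to_logprobs(logits, input_ids, trim_bos=True)
    # Trimming removes one position, and log-probabilities are never positive.
    assert logprobs.shape == (batch, length - 1)
    assert bool((logprobs <= 0).all())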


def score_sequences(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    reduce_method: str = 'mean',
    device: str = 'cuda:0',
) -> List[float]:
    """
    Computes the model log-likelihood scores for sequences in `seqs`.
    Uses `reduce_method` to take the mean or sum of the log-likelihoods at each
    position (default: `'mean'`).

    Returns a list of scalar scores corresponding to the reduced log-likelihoods
    for each sequence.
    """
    input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)

    logprobs = logits_to_logprobs(logits, input_ids, trim_bos=True)
    logprobs = logprobs.float().cpu().numpy()

    if reduce_method == 'mean':
        reduce_func = np.mean
    elif reduce_method == 'sum':
        reduce_func = np.sum
    else:
        raise ValueError(f'Invalid reduce_method {reduce_method}')

    # Truncate each row to the original (unpadded) sequence length before
    # reducing, so padding positions do not contribute to the score.
    return [
        reduce_func(logprobs[idx][:seq_lengths[idx]])
        for idx in range(len(seq_lengths))
    ]


def positional_entropies(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    device: str = 'cuda:0',
) -> List[np.ndarray]:
    """
    Computes the positional entropies for sequences in `seqs`.

    Returns a list of arrays, where each array is the same length as the
    corresponding sequence. Each array contains the per-position entropy
    across the vocab dimension.
    """
    input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)

    # Drop the final position so that, as in `logits_to_logprobs`, the
    # distribution at position i corresponds to the token at position i + 1.
    softmax_logprobs = torch.log_softmax(logits, dim=-1)[:, :-1]

    # Shannon entropy: -sum(p * log p) over the vocab dimension.
    entropies = -torch.sum(torch.exp(softmax_logprobs) * softmax_logprobs, dim=-1)
    entropies = entropies.float().cpu().numpy()

    # Truncate each row to the original (unpadded) sequence length.
    sequence_entropies = [
        entropies[idx][:seq_lengths[idx]] for idx in range(len(seq_lengths))
    ]
    assert all(
        len(seq) == len(entropy) for seq, entropy in zip(seqs, sequence_entropies)
    )

    return sequence_entropies
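

# End-to-end usage sketch, assuming a StripedHyena model and CharLevelTokenizer
# have already been instantiated (checkpoint loading is configuration-specific
# and omitted here; see the stripedhyena repository for details):
#
#     tokenizer = CharLevelTokenizer(512)  # vocab size is an assumption; check your config
#     model = ...  # load a StripedHyena checkpoint onto 'cuda:0'
#
#     scores = score_sequences(['ACGT', 'ACGTACGT'], model, tokenizer)
#     entropies = positional_entropies(['ACGT'], model, tokenizer)
#     # `scores` holds one mean log-likelihood per sequence;
#     # `entropies[0]` is a length-4 array of per-position entropies.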