# Evo-App/evo/scoring.py
import numpy as np
import torch
from typing import List, Tuple
from stripedhyena.model import StripedHyena
from stripedhyena.tokenizer import CharLevelTokenizer


def prepare_batch(
    seqs: List[str],
    tokenizer: CharLevelTokenizer,
    prepend_bos: bool = True,
    device: str = 'cuda:0',
) -> Tuple[torch.Tensor, List[int]]:
    """
    Tokenizes a list of sequences and collects them into a single batch tensor.
    Sequences of differing lengths are right-padded up to the maximum sequence length.
    Returns the batch of input ids and the original (unpadded) sequence lengths.
    """
    seq_lengths = [len(seq) for seq in seqs]
    max_seq_length = max(seq_lengths)

    input_ids = []
    for seq in seqs:
        padding = [tokenizer.pad_id] * (max_seq_length - len(seq))
        input_ids.append(
            torch.tensor(
                # The EOD token doubles as a BOS token when `prepend_bos` is set.
                ([tokenizer.eod_id] * int(prepend_bos)) + tokenizer.tokenize(seq) + padding,
                dtype=torch.long,
            ).to(device).unsqueeze(0)
        )
    input_ids = torch.cat(input_ids, dim=0)

    return input_ids, seq_lengths


def logits_to_logprobs(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    trim_bos: bool = True,
) -> torch.Tensor:
    """
    Takes in a tensor of logits of dimension (batch, length, vocab).
    Computes the log-probabilities using a softmax along the vocab dimension.
    Uses the `input_ids` to index into the log-probabilities and returns the
    log-likelihood of the provided sequence at each position with dimension
    (batch, length).
    """
    softmax_logprobs = torch.log_softmax(logits, dim=-1)
    if trim_bos:
        softmax_logprobs = softmax_logprobs[:, :-1]  # Drop the prediction after the final token.
        input_ids = input_ids[:, 1:]  # Drop the prepended BOS so targets align with predictions.
    assert softmax_logprobs.shape[1] == input_ids.shape[1]

    logprobs = torch.gather(
        softmax_logprobs,        # Gather log-probabilities...
        2,                       # along the vocab dimension...
        input_ids.unsqueeze(-1)  # using the token ids to index.
    ).squeeze(-1)

    return logprobs


def score_sequences(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    reduce_method: str = 'mean',
    device: str = 'cuda:0',
) -> List[float]:
    """
    Computes the model log-likelihood scores for sequences in `seqs`.
    Uses `reduce_method` to take the mean or sum of the per-position
    log-likelihoods (default: `'mean'`).
    Returns a list of scalar scores, one reduced log-likelihood per sequence.
    """
    input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)  # (batch, length, vocab)

    logprobs = logits_to_logprobs(logits, input_ids, trim_bos=True)
    logprobs = logprobs.float().cpu().numpy()

    if reduce_method == 'mean':
        reduce_func = np.mean
    elif reduce_method == 'sum':
        reduce_func = np.sum
    else:
        raise ValueError(f'Invalid reduce_method {reduce_method}')

    # Truncate each row to the original sequence length so padding positions are excluded.
    return [
        reduce_func(logprobs[idx][:seq_lengths[idx]])
        for idx in range(len(seq_lengths))
    ]


def positional_entropies(
    seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    device: str = 'cuda:0',
) -> List[np.ndarray]:
    """
    Computes the positional entropies for sequences in `seqs`.
    Returns a list of arrays, where each array has the same length as the
    corresponding sequence. Each entry is the per-position entropy of the
    model's predictive distribution across the vocab dimension.
    """
    input_ids, seq_lengths = prepare_batch(seqs, tokenizer, device=device, prepend_bos=True)
    assert len(seq_lengths) == input_ids.shape[0]

    with torch.inference_mode():
        logits, _ = model(input_ids)  # (batch, length, vocab)

    # `prepare_batch` prepends a BOS token, so drop the prediction after the final token.
    softmax_logprobs = torch.log_softmax(logits, dim=-1)[:, :-1]
    entropies = -torch.sum(torch.exp(softmax_logprobs) * softmax_logprobs, dim=-1)
    entropies = entropies.float().cpu().numpy()

    # Truncate each row to the original sequence length so padding positions are excluded.
    sequence_entropies = [
        entropies[idx][:seq_lengths[idx]] for idx in range(len(seq_lengths))
    ]
    assert all(
        len(seq) == len(entropy) for seq, entropy in zip(seqs, sequence_entropies)
    )

    return sequence_entropies
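

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the app's runtime path). It assumes the
# model and tokenizer are obtained through the Evo package's `Evo` wrapper;
# the import path and checkpoint name below are assumptions and may differ
# from how this app actually loads the model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from evo import Evo  # assumed import; replace with the app's own loading code

    evo_model = Evo('evo-1-8k-base')  # hypothetical checkpoint name
    model, tokenizer = evo_model.model, evo_model.tokenizer
    model.to('cuda:0')
    model.eval()

    seqs = ['ACGTACGTACGT', 'ACGT']
    scores = score_sequences(seqs, model, tokenizer, reduce_method='mean')
    entropies = positional_entropies(seqs, model, tokenizer)

    for seq, score, entropy in zip(seqs, scores, entropies):
        print(f'{seq}\tmean logprob: {score:.4f}\tentropy length: {len(entropy)}')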