# Evo-App / evo/generation.py
# Committed by: sochasticbackup
# Commit message: resolved generation task errors
# Commit: 0b4c4af
import numpy as np
import sys
import torch
from typing import List, Tuple, Union
from stripedhyena.model import StripedHyena
from stripedhyena.sample import sample
from stripedhyena.tokenizer import CharLevelTokenizer
from .scoring import logits_to_logprobs, prepare_batch
class Generator:
    '''
    Adapted from https://github.com/togethercomputer/stripedhyena.
    Modifications include:
    - `generate()` accepts and returns the recurrent cache state, letting the user
      keep track of it across sampling runs.
    - Able to sample with long token prompts in which the cache is initialized with
      recurrent teacher forcing.
    '''

    def __init__(
        self,
        model: StripedHyena,
        tokenizer: CharLevelTokenizer,
        top_k: int = 50,
        top_p: float = 0.7,
        temperature: float = 1.,
    ):
        self.model = model
        self.tokenizer = tokenizer
        # Sampling hyper-parameters forwarded to `sample()` at every step.
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        # Substrings at which the verbose output preview is cut off.
        self.untils = ['\n\n']

    @staticmethod
    def _move_inference_params_to_device(inference_params_dict: dict, device) -> None:
        """Move every cached tensor of a previously returned cache dict to `device` (in place)."""
        containers = (
            inference_params_dict['mha'].key_value_memory_dict,
            inference_params_dict['hyena'].fir_state_dict,
            inference_params_dict['hyena'].state_dict,
        )
        for container in containers:
            # Re-assigning existing keys does not change the dict's size, so
            # mutating while iterating is safe here.
            for key, data in container.items():
                container[key] = data.to(device)

    @staticmethod
    def _log_cuda_memory(label: str, device) -> None:
        """Print the CUDA memory allocated on `device` in GB; skipped for non-CUDA devices."""
        if torch.cuda.is_available() and torch.device(device).type == 'cuda':
            mem = torch.cuda.memory_allocated(device=device) / 1e9
            print(f'{label}: {mem} GB')

    def _truncate_at_untils(self, text: str) -> str:
        """Cut `text` at the first configured stop substring found (checked in `self.untils` order)."""
        for until in self.untils:
            if until in text:
                return text.split(until)[0]
        return text

    def generate(
        self,
        device: str,
        input_string: str = None,
        input_ids: torch.Tensor = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int = 128,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        max_seqlen: int = None,
        inference_params_dict: dict = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Sample up to `num_tokens` tokens autoregressively from `self.model`.

        A version of the generate() method that accepts and returns the
        `inference_params_dict`, so cached sampling can be replayed/resumed from
        a given state.

        Args:
            device: Device on which a tokenized `input_string` is placed.
            input_string: Prompt text, tokenized with `self.tokenizer` when
                `input_ids` is None.
            input_ids: Pre-tokenized prompt indexed as (batch, seq); takes
                precedence over `input_string`.
            num_tokens: Number of tokens to sample after the prompt.
            cached_generation: Use the model's inference cache so that, after the
                first forward pass, only the newest token is fed per step.
            force_prompt_threshold: Prompts longer than this are split. The first
                `force_prompt_threshold` tokens are fed in one forward pass. The
                remainder is fed one token per step (teacher forcing) before
                sampling begins.
            print_generation: With `verbose` and batch size 1, print each
                token as it is produced.
            verbose: Print the prompt, the final output and, on CUDA devices,
                memory usage.
            skip_special_tokens: Accepted for interface compatibility; unused.
            stop_at_eos: Stop early once every sequence in the batch ends with
                the tokenizer's EOS token id(s) on the same step.
            max_seqlen: If given, keep only the last `max_seqlen` prompt tokens.
            inference_params_dict: Cache state returned by a previous call. When
                given, cached generation is forced on, the cache is moved to the
                prompt's device, and only the last prompt token is fed at the
                first step (the earlier context is presumably already cached).

        Returns:
            `(generation, scores, inference_params_dict)`:
            - `generation`: sampled token ids of shape (batch, n).
            - `scores`: raw logits at each sampled step, of shape
              (batch, n, vocab_size).
            - `inference_params_dict`: the possibly updated cache.
            `n == num_tokens` unless generation stopped at EOS.

        Raises:
            ValueError: If neither `input_string` nor `input_ids` is given.
        """
        # --- Prompt preparation -------------------------------------------------
        if input_ids is None:
            if input_string is None:
                raise ValueError('Either `input_string` or `input_ids` must be provided.')
            tokens = self.tokenizer.tokenize(input_string)
            if isinstance(tokens, list):
                prompt_ids = torch.LongTensor(tokens).unsqueeze(0).to(device)
            else:  # already a tensor
                prompt_ids = tokens.unsqueeze(0).to(device)
        else:
            prompt_ids = input_ids

        x = prompt_ids
        if max_seqlen is not None:
            x = x[:, -max_seqlen:]
        num_tokens = int(num_tokens)
        batch_size = x.shape[0]
        prompt_length = x.shape[1]

        # EOS ids live on the same device as the generation buffer they are
        # compared against; `as_tensor` accepts both int-list and tensor output.
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(x.device)
        else:
            eos_token_ids = torch.as_tensor(
                self.tokenizer.tokenize(self.tokenizer.eos),
                dtype=torch.long,
                device=x.device,
            ).flatten()
        n_eos = eos_token_ids.numel()

        prompt_forcing = prompt_length > force_prompt_threshold
        if prompt_forcing:
            forced_prompt_length = prompt_length - force_prompt_threshold
            x_force = x[:, force_prompt_threshold:]
            x = x[:, :force_prompt_threshold]
        else:
            forced_prompt_length = 0
        # Number of positions consumed by the first (prefill) forward pass; this,
        # not the raw prompt length, is where the cache offset must resume.
        prefill_length = x.shape[1]

        generation = torch.empty(
            batch_size,
            num_tokens,
            dtype=torch.long,
            device=x.device,
        )
        scores = torch.empty(
            batch_size,
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        # --- Cache setup --------------------------------------------------------
        prefilled = False
        if inference_params_dict is not None:
            cached_generation = True
            prefilled = True
            self._move_inference_params_to_device(inference_params_dict, x.device)
        elif cached_generation:
            inference_params_dict = self.model.initialize_inference_params()
            inference_params_dict['mha'].max_batch_size = batch_size
            inference_params_dict['hyena'].max_batch_size = batch_size

        if verbose:
            self._log_cuda_memory('Memory after tokenization', x.device)
            print('Starting generation...')
            if input_string is not None:
                print('Prompt: ' + input_string)
            else:
                print(f'Prompt ids: {input_ids} {input_ids.shape}')

        # --- Autoregressive loop ------------------------------------------------
        n_generated = 0
        for i in range(forced_prompt_length + num_tokens):
            post_prefill = prefilled or (cached_generation and i > 0)

            if post_prefill:
                # Earlier positions are in the cache: feed only the newest token.
                x = x[:, -1:]
                mha_params = inference_params_dict['mha']
                hyena_params = inference_params_dict['hyena']
                if mha_params.seqlen_offset == 0:
                    # First step after the prefill pass.
                    mha_params.seqlen_offset = prefill_length
                    hyena_params.seqlen_offset = prefill_length
                else:
                    mha_params.seqlen_offset += 1
                    hyena_params.seqlen_offset += 1

            with torch.inference_mode():
                logits, inference_params_dict = self.model(
                    x,
                    inference_params_dict=inference_params_dict,
                )
            last_logits = logits[:, -1]

            forcing_step = prompt_forcing and i < forced_prompt_length
            if forcing_step:
                # Teacher forcing: feed the next real prompt token instead of sampling.
                new_idx = x_force[:, i]
            else:
                new_idx = sample(
                    last_logits,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                )

            if print_generation and verbose and batch_size == 1:
                print(
                    f'{self.tokenizer.detokenize([new_idx.item()])}',
                    end=' ',
                )

            if not forcing_step:
                step = i - forced_prompt_length
                scores[:, step] = last_logits
                generation[:, step] = new_idx
                n_generated = step + 1

            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

            # Only tokens actually written to `generation` are inspected, and the
            # loop stops only when every batch row ends with the EOS sequence.
            if (
                stop_at_eos
                and not forcing_step
                and n_generated >= n_eos
                and (generation[:, n_generated - n_eos:n_generated] == eos_token_ids).all()
            ):
                print('Stopping generation at EOS')
                break

        if verbose:
            y = self.tokenizer.detokenize_batch(generation[:, :n_generated])
            if isinstance(y, str):
                y = self._truncate_at_untils(y)
            else:
                y = [self._truncate_at_untils(text) for text in y]
            print(f'\nInput: {input_string}, Output: {y}')
            self._log_cuda_memory('Memory after generation', x.device)

        return generation[:, :n_generated], scores[:, :n_generated], inference_params_dict
def generate(
    prompt_seqs: List[str],
    model: StripedHyena,
    tokenizer: CharLevelTokenizer,
    n_tokens: int = 100,
    temperature: float = 0.,
    top_k: int = 1,
    top_p: float = 1.,
    batched: bool = True,
    prepend_bos: bool = False,
    cached_generation: bool = False,
    force_prompt_threshold: int = 128,
    verbose: int = 1,
    device: str = 'cuda:0',
    **kwargs,
) -> Tuple[List[str], List[float]]:
    """
    Performs generation from a list of prompts.

    If all prompts are the same length, this can do batched generation.
    Otherwise each prompt is run on its own. Cached generation is also
    supported for efficient sampling.

    Args:
        prompt_seqs: Prompt strings; an empty list yields `([], [])`.
        model: Model to sample from; it is switched to eval mode.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        n_tokens: Number of tokens to generate per prompt.
        temperature, top_k, top_p: Sampling parameters passed to `Generator`.
        batched: Run all prompts in one batch when their lengths are uniform.
        prepend_bos: Forwarded to `prepare_batch`.
        cached_generation: Forwarded to `Generator.generate`.
        force_prompt_threshold: Forwarded to `Generator.generate`.
        verbose: 0 is silent. 1 prints notes and a per-prompt summary. Values
            above 1 also print per-token output and tensor shapes.
        device: Device the prompts are placed on.
        **kwargs: Accepted and ignored.

    Returns:
        `(generated_seqs, generated_scores)`, aligned with `prompt_seqs`. Each
        score is the mean of the log-probabilities that `logits_to_logprobs`
        assigns to the generated tokens.
    """
    model.eval()

    # Without this guard, an empty list would take the batched branch (`all()`
    # of nothing is True) and hand an empty batch to `prepare_batch`.
    if len(prompt_seqs) == 0:
        return [], []

    g = Generator(
        model,
        tokenizer,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
    )

    uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

    if batched and uniform_lengths:
        input_ids_list = [
            prepare_batch(
                prompt_seqs,
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
        ]
    else:
        if verbose:
            if not uniform_lengths:
                sys.stderr.write('Note: Prompts are of different lengths.\n')
            sys.stderr.write('Note: Will not do batched generation.\n')
        input_ids_list = [
            prepare_batch(
                [prompt_seq],
                tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for prompt_seq in prompt_seqs
        ]

    generated_seqs, generated_scores = [], []
    for input_ids in input_ids_list:
        batch_size = input_ids.shape[0]

        output_ids, logits, _ = g.generate(
            input_ids=input_ids,
            num_tokens=n_tokens,
            cached_generation=cached_generation,
            force_prompt_threshold=force_prompt_threshold,
            device=device,
            print_generation=(verbose > 1),
            verbose=(verbose > 1),
            stop_at_eos=False,
        )
        if verbose > 1:
            print('input_ids.shape', input_ids.shape)
            print('output_ids.shape', output_ids.shape)
            print('logits.shape', logits.shape)

        generated_seqs_batch = list(tokenizer.detokenize_batch(output_ids))
        assert len(generated_seqs_batch) == batch_size
        generated_seqs += generated_seqs_batch

        logprobs = logits_to_logprobs(logits, output_ids)
        logprobs = logprobs.float().cpu().numpy()
        generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]

    assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
    if verbose:
        for seq, score, prompt in zip(generated_seqs, generated_scores, prompt_seqs):
            print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

    return generated_seqs, generated_scores