import os
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
import argparse

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Feature extraction for CLaMP3.")
parser.add_argument("--epoch", type=str, default=None, help="Epoch of the checkpoint to load.")
parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
parser.add_argument("--get_global", action="store_true", help="Get global feature.")
args = parser.parse_args()

# Retrieve arguments
epoch = args.epoch
input_dir = args.input_dir
output_dir = args.output_dir
get_global = args.get_global
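
# Example invocation (illustrative; assumes this script is saved as extract_clamp3.py):
#   python extract_clamp3.py <input_dir> <output_dir> --get_global
# Multi-GPU extraction is typically launched through `accelerate launch` rather than `python`.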

# Collect all supported input files (.txt, .abc, .mtf, .npy) under input_dir
files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf") or f.endswith(".npy"):
            files.append(os.path.join(root, f))
print(f"Found {len(files)} files in total")

# Initialize accelerator and device
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)

# Model and configuration setup
audio_config = BertConfig(vocab_size=1,
                          hidden_size=AUDIO_HIDDEN_SIZE,
                          num_hidden_layers=AUDIO_NUM_LAYERS,
                          num_attention_heads=AUDIO_HIDDEN_SIZE//64,
                          intermediate_size=AUDIO_HIDDEN_SIZE*4,
                          max_position_embeddings=MAX_AUDIO_LENGTH)
symbolic_config = BertConfig(vocab_size=1,
                             hidden_size=M3_HIDDEN_SIZE,
                             num_hidden_layers=PATCH_NUM_LAYERS,
                             num_attention_heads=M3_HIDDEN_SIZE//64,
                             intermediate_size=M3_HIDDEN_SIZE*4,
                             max_position_embeddings=PATCH_LENGTH)
model = CLaMP3Model(audio_config=audio_config,
                    symbolic_config=symbolic_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP3_HIDDEN_SIZE,
                    load_m3=CLAMP3_LOAD_M3)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# Print the total parameter count
print("Total Parameter Number: " + str(sum(p.numel() for p in model.parameters())))

# Load model weights
model.eval()
checkpoint_path = CLAMP3_WEIGHTS_PATH
if epoch is not None:
    checkpoint_path = CLAMP3_WEIGHTS_PATH.replace(".pth", f"_{epoch}.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
print(f"Successfully Loaded CLaMP 3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])
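
# NOTE: the `weights_only=True` argument to torch.load requires a reasonably
# recent PyTorch (it was introduced around 1.13); drop it on older versions.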

def extract_feature(filename, get_global=get_global):
    # Read text-based inputs (.txt, .abc, .mtf); .npy audio features are loaded below
    if not filename.endswith(".npy"):
        with open(filename, "r", encoding="utf-8") as f:
            item = f.read()

    if filename.endswith(".txt"):
        # Text: deduplicate lines, drop empties, join with the tokenizer's separator token
        item = list(set(item.split("\n")))
        item = "\n".join(item)
        item = item.split("\n")
        item = [c for c in item if len(c) > 0]
        item = tokenizer.sep_token.join(item)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    elif filename.endswith(".abc") or filename.endswith(".mtf"):
        # Symbolic music: encode into M3 patches
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH
    elif filename.endswith(".npy"):
        # Audio: pre-extracted frame features, padded with a zero vector at both ends
        input_data = np.load(filename)
        input_data = torch.tensor(input_data)
        input_data = input_data.reshape(-1, input_data.size(-1))
        zero_vec = torch.zeros((1, input_data.size(-1)))
        input_data = torch.cat((zero_vec, input_data, zero_vec), 0)
        max_input_length = MAX_AUDIO_LENGTH
    else:
        raise ValueError(f"Unsupported file type: {filename}, only support .txt, .abc, .mtf, .npy files")

    # Split into segments of at most max_input_length; the final segment is
    # re-sliced from the end so it always has full context (it may overlap
    # the previous one)
    segment_list = []
    for i in range(0, len(input_data), max_input_length):
        segment_list.append(input_data[i:i+max_input_length])
    segment_list[-1] = input_data[-max_input_length:]
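
    # Each segment below is padded to the fixed encoder length, given an attention
    # mask over its real tokens, and routed to the text, symbolic, or audio tower.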
    last_hidden_states_list = []
    for input_segment in segment_list:
        input_masks = torch.tensor([1]*input_segment.size(0))
        if filename.endswith(".txt"):
            pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
        elif filename.endswith(".abc") or filename.endswith(".mtf"):
            pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        else:
            pad_indices = torch.ones((MAX_AUDIO_LENGTH - input_segment.size(0), AUDIO_HIDDEN_SIZE)).float() * 0.
        input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)
        if filename.endswith(".txt"):
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_global=get_global)
        elif filename.endswith(".abc") or filename.endswith(".mtf"):
            last_hidden_states = model.get_symbolic_features(symbolic_inputs=input_segment.unsqueeze(0).to(device),
                                                             symbolic_masks=input_masks.unsqueeze(0).to(device),
                                                             get_global=get_global)
        else:
            last_hidden_states = model.get_audio_features(audio_inputs=input_segment.unsqueeze(0).to(device),
                                                          audio_masks=input_masks.unsqueeze(0).to(device),
                                                          get_global=get_global)
        if not get_global:
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)
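
    # Aggregate the per-segment outputs: local features are concatenated (trimming
    # the overlap in the re-sliced final segment), while global features are
    # averaged with weights proportional to each segment's true length.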
    if not get_global:
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        # [-0:] keeps the whole segment when the length divides evenly, which is correct
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        full_chunk_cnt = len(input_data) // max_input_length
        remain_chunk_len = len(input_data) % max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()
    return last_hidden_states_list
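
# extract_feature returns a single feature vector (size CLAMP3_HIDDEN_SIZE) when
# get_global is set, and the concatenated per-token/per-patch hidden states otherwise.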

def process_directory(input_dir, output_dir, files):
    # Calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # Calculate the start and end index for the current GPU;
    # the last process picks up any remainder
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)
    files_to_process = files[start_idx:end_idx]

    # Process the files
    for file in tqdm(files_to_process):
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " cannot be created\n" + str(e))
        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")
        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            continue
        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
                np.save(output_file, features.detach().cpu().numpy())
        except Exception as e:
            print(f"Failed to process {file}: {e}")

# Process the files
process_directory(input_dir, output_dir, files)
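
# Example downstream use (illustrative; the file names are hypothetical): with
# --get_global, each saved .npy holds one feature vector, and cross-modal
# retrieval is typically done by cosine similarity between such vectors:
#
#   import numpy as np
#   a = np.load("features/query.npy")[0]
#   b = np.load("features/candidate.npy")[0]
#   similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))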