import os
import sys
import argparse
import json
import pandas as pd
import torch
from tqdm import tqdm
from Bio import SeqIO
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.data.prosst.structure.quantizer import PdbQuantizer
from src.utils.data_utils import extract_seq_from_pdb
import warnings

warnings.filterwarnings("ignore", category=Warning)

structure_vocab_size = 20
processor = PdbQuantizer(structure_vocab_size=structure_vocab_size)
def get_prosst_token(pdb_file):
    """Generate ProSST structure tokens for a PDB file."""
    try:
        # Extract the amino acid sequence from the PDB file
        aa_seq = extract_seq_from_pdb(pdb_file)
        # Quantize the structure into a discrete token sequence
        structure_result = processor(pdb_file)
        pdb_name = os.path.basename(pdb_file)
        # Validate the quantizer output
        if structure_vocab_size not in structure_result:
            raise ValueError(f"Missing structure key: {structure_vocab_size}")
        if pdb_name not in structure_result[structure_vocab_size]:
            raise ValueError(f"Missing PDB entry: {pdb_name}")
        struct_sequence = structure_result[structure_vocab_size][pdb_name]['struct']
        struct_sequence = [int(num) for num in struct_sequence]
        # Offset raw tokens by 3, then add special tokens: [1] + sequence + [2]
        structure_sequence_offset = [3 + num for num in struct_sequence]
        structure_input_ids = torch.tensor(
            [[1] + structure_sequence_offset + [2]],
            dtype=torch.long
        )
        return {
            "name": os.path.basename(pdb_file).split('.')[0],
            "aa_seq": aa_seq,
            "struct_tokens": structure_input_ids[0].tolist()
        }, None
    except Exception as e:
        # On failure, return the file path and the error message instead of a record
        return pdb_file, f"{str(e)}"
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ProSST structure token generator')
    parser.add_argument('--pdb_dir', type=str, help='Directory containing PDB files')
    parser.add_argument('--pdb_file', type=str, help='Single PDB file path')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of parallel workers')
    parser.add_argument('--pdb_index_file', type=str, default=None, help='PDB index file for sharding')
    parser.add_argument('--pdb_index_level', type=int, default=1, help='Directory hierarchy depth')
    parser.add_argument('--error_file', type=str, help='Error log output path')
    parser.add_argument('--out_file', type=str, required=True, help='Output JSON file path')
    args = parser.parse_args()

    if args.pdb_dir is not None:
        # Load PDB names from an index file and build sharded paths
        if args.pdb_index_file:
            pdbs = open(args.pdb_index_file).read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level + 1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb + ".pdb"))
        # Regular PDB directory
        else:
            pdb_files = sorted([os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir)])
        # Process PDB files in parallel
        results, errors = [], []
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = {executor.submit(get_prosst_token, f): f for f in pdb_files}
            with tqdm(total=len(futures), desc="Processing PDBs") as progress:
                for future in as_completed(futures):
                    result, error = future.result()
                    if error:
                        errors.append({"file": result, "error": error})
                    else:
                        results.append(result)
                    progress.update(1)

        if errors:
            error_path = args.error_file or args.out_file.replace('.json', '_errors.csv')
            pd.DataFrame(errors).to_csv(error_path, index=False)
            print(f"Encountered {len(errors)} errors. Saved to {error_path}")

        # Write one JSON record per line (JSON Lines)
        with open(args.out_file, 'w') as f:
            f.write('\n'.join(json.dumps(r) for r in results))
    elif args.pdb_file:
        result, error = get_prosst_token(args.pdb_file)
        if error:
            raise RuntimeError(f"Error processing {args.pdb_file}: {error}")
        with open(args.out_file, 'w') as f:
            json.dump(result, f)
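A minimal sketch of how the batch output could be consumed downstream, assuming the script was run with --pdb_dir and that prosst_tokens.json is a placeholder output path: the --pdb_dir branch writes one JSON record per line, while the --pdb_file branch writes a single JSON object.

import json
import torch

def load_prosst_records(path):
    """Load the newline-delimited JSON written by the --pdb_dir branch."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

records = load_prosst_records("prosst_tokens.json")  # placeholder path
for rec in records[:3]:
    # Rebuild the [1] + offset tokens + [2] tensor produced by get_prosst_token
    structure_input_ids = torch.tensor([rec["struct_tokens"]], dtype=torch.long)
    print(rec["name"], len(rec["aa_seq"]), structure_input_ids.shape)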