import argparse
import logging
import os
from typing import Optional, Tuple, Union

import torch.nn as nn
from transformers import PreTrainedModel
def setup_logging(args: argparse.Namespace) -> logging.Logger:
    """Set up logging to both the console and a file in args.output_dir."""
    logger = logging.getLogger('training')
    logger.setLevel(logging.INFO)
    # Clear any handlers left over from a previous call so repeated
    # invocations do not duplicate every log line.
    logger.handlers.clear()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler; create the output directory first so FileHandler
    # does not raise FileNotFoundError on a fresh run.
    os.makedirs(args.output_dir, exist_ok=True)
    log_file = os.path.join(
        args.output_dir,
        f'{args.output_model_name.split(".")[0]}_training.log'
    )
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Log the full run configuration up front.
    logger.info("Starting training with configuration:")
    for key, value in vars(args).items():
        logger.info(f"{key}: {value}")

    return logger
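
# Example usage (a minimal sketch, not from the original code): the function
# only reads the `output_dir` and `output_model_name` attributes, so an
# argparse.Namespace built by any training CLI works, e.g.:
#
#     args = argparse.Namespace(output_dir="./outputs",
#                               output_model_name="adapter.pt")
#     logger = setup_logging(args)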
def count_parameters(model: Union[nn.Module, PreTrainedModel]) -> Tuple[int, int]:
    """
    Count total and trainable parameters in model.

    Args:
        model: PyTorch model or Hugging Face model

    Returns:
        Tuple of (total_params, trainable_params)
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params, trainable_params
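
# Quick sanity check (illustrative values, not from the original source):
#
#     layer = nn.Linear(128, 64)   # 128*64 weights + 64 biases
#     count_parameters(layer)      # -> (8256, 8256)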
def format_parameter_count(count: int) -> str:
    """Format parameter count with appropriate unit."""
    if count >= 1e9:
        return f"{count / 1e9:.2f}B"
    elif count >= 1e6:
        return f"{count / 1e6:.2f}M"
    elif count >= 1e3:
        return f"{count / 1e3:.2f}K"
    return str(count)
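
# Expected outputs (a small illustration):
#
#     format_parameter_count(8256)           # -> "8.26K"
#     format_parameter_count(124_000_000)    # -> "124.00M"
#     format_parameter_count(7_000_000_000)  # -> "7.00B"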
def print_model_parameters(
    model: nn.Module,
    plm_model: PreTrainedModel,
    logger: Optional[logging.Logger] = None,
):
    """
    Print parameter statistics for both the adapter and the PLM.

    Args:
        model: Adapter model
        plm_model: Pre-trained language model
        logger: Optional logger for output; falls back to print()
    """
    # Count adapter and PLM parameters separately, then combine.
    adapter_total, adapter_trainable = count_parameters(model)
    plm_total, plm_trainable = count_parameters(plm_model)

    combined_total = adapter_total + plm_total
    combined_trainable = adapter_trainable + plm_trainable

    output = [
        "------------------------",
        "Model Parameters Statistics:",
        "------------------------",
        "Adapter Model:",
        f"  Total parameters: {format_parameter_count(adapter_total)}",
        f"  Trainable parameters: {format_parameter_count(adapter_trainable)}",
        "Pre-trained Model:",
        f"  Total parameters: {format_parameter_count(plm_total)}",
        f"  Trainable parameters: {format_parameter_count(plm_trainable)}",
        "Combined:",
        f"  Total parameters: {format_parameter_count(combined_total)}",
        f"  Trainable parameters: {format_parameter_count(combined_trainable)}",
        f"  Trainable percentage: {combined_trainable / combined_total * 100:.2f}%",
        "------------------------",
    ]

    # Route every line through the logger if one was provided.
    out = logger.info if logger else print
    for line in output:
        out(line)
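
# Minimal smoke test (a hedged sketch, not part of the original module).
# `count_parameters` and `print_model_parameters` only call .parameters(),
# so small stand-in modules can substitute for a real adapter and a real
# Hugging Face PreTrainedModel here.
if __name__ == "__main__":
    adapter = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 2))
    backbone = nn.Linear(768, 768)
    for p in backbone.parameters():
        p.requires_grad = False  # simulate a frozen pre-trained backbone

    print_model_parameters(adapter, backbone)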