| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import random |
| | import subprocess |
| | import sys |
| | from enum import Enum, unique |
| |
|
| | from . import launcher |
| | from .api.app import run_api |
| | from .chat.chat_model import run_chat |
| | from .eval.evaluator import run_eval |
| | from .extras.env import VERSION, print_env |
| | from .extras.logging import get_logger |
| | from .extras.misc import get_device_count |
| | from .train.tuner import export_model, run_exp |
| | from .webui.interface import run_web_demo, run_web_ui |
| |
|
| |
|
# Total width (borders included) of the ASCII box drawn around the usage text.
_USAGE_WIDTH = 70

# One row per line of help text, shown in this order inside the box.
# NOTE(review): the "env" sub-command handled by main() is not listed here —
# confirm whether it is intentionally hidden.
_USAGE_ROWS = (
    "Usage:",
    "llamafactory-cli api -h: launch an OpenAI-style API server",
    "llamafactory-cli chat -h: launch a chat interface in CLI",
    "llamafactory-cli eval -h: evaluate models",
    "llamafactory-cli export -h: merge LoRA adapters and export model",
    "llamafactory-cli train -h: train models",
    "llamafactory-cli webchat -h: launch a chat interface in Web UI",
    "llamafactory-cli webui: launch LlamaBoard",
    "llamafactory-cli version: show version info",
)

# Help text printed for the "help" command (or when no command is given).
# Each row is padded to the box width so the right-hand border stays aligned
# regardless of the row's length: "| " + text + " |" == _USAGE_WIDTH columns.
USAGE = (
    "-" * _USAGE_WIDTH
    + "\n"
    + "".join("| {} |\n".format(row.ljust(_USAGE_WIDTH - 4)) for row in _USAGE_ROWS)
    + "-" * _USAGE_WIDTH
)
| |
|
# Banner printed by the "version" command: a 58-column box showing VERSION
# and the project URL. The version row is right-padded with (21 - len(VERSION))
# spaces so its closing "|" lands in column 58; for versions longer than 21
# characters the padding is simply empty (str * negative int == "").
WELCOME = "\n".join(
    (
        "-" * 58,
        "| Welcome to LLaMA Factory, version {}".format(VERSION) + " " * (21 - len(VERSION)) + "|",
        "|" + " " * 56 + "|",
        "| Project page: https://github.com/hiyouga/LLaMA-Factory |",
        "-" * 58,
    )
)
| |
|
# Module-level logger from the project's get_logger helper; main() uses it to
# report the torchrun master address/port before launching distributed training.
logger = get_logger(__name__)
| |
|
| |
|
@unique  # reject accidental duplicate command strings at import time
class Command(str, Enum):
    """Sub-commands accepted as the first argument of the CLI.

    Mixing in ``str`` lets ``main()`` compare the raw ``sys.argv`` token
    directly against members (e.g. ``"train" == Command.TRAIN``).
    """

    API = "api"
    CHAT = "chat"
    ENV = "env"
    EVAL = "eval"
    EXPORT = "export"
    TRAIN = "train"
    WEBDEMO = "webchat"  # note: the member name differs from its CLI spelling
    WEBUI = "webui"
    VER = "version"
    HELP = "help"
| |
|
| |
|
def _launch_torchrun() -> None:
    """Re-run training through ``torchrun`` and exit with its return code.

    Rendezvous settings come from the environment, with single-node defaults:
    MASTER_ADDR (127.0.0.1), MASTER_PORT (random port in 20001-29999),
    NNODES (1), RANK (0) and NPROC_PER_NODE (the detected device count).
    The remaining ``sys.argv`` entries are forwarded to ``launcher.__file__``.
    """
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
    logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
    # Build an argv list and run it without a shell. Joining the arguments
    # into one shell string would re-split user arguments that contain spaces
    # or quotes, and would let shell metacharacters in arguments or
    # environment values execute as commands.
    torchrun_cmd = [
        "torchrun",
        "--nnodes", os.environ.get("NNODES", "1"),
        "--node_rank", os.environ.get("RANK", "0"),
        "--nproc_per_node", os.environ.get("NPROC_PER_NODE", str(get_device_count())),
        "--master_addr", master_addr,
        "--master_port", master_port,
        launcher.__file__,
        *sys.argv[1:],
    ]
    process = subprocess.run(torchrun_cmd)
    sys.exit(process.returncode)


def main():
    """Entry point of ``llamafactory-cli``: dispatch on the first CLI argument.

    The sub-command token is popped from ``sys.argv`` so the remaining
    arguments are left in place (presumably for the sub-command's own argument
    parsing). With no arguments, the usage text is printed.

    For ``train``, training runs in-process via ``run_exp()`` unless
    FORCE_TORCHRUN is "true"/"1" (case-insensitive) or more than one device is
    detected. In those cases the process re-launches itself under ``torchrun``
    and exits with torchrun's return code.

    Raises:
        NotImplementedError: if the first argument is not a known command.
    """
    # `> 1` (not `!= 1`) so an empty sys.argv (embedded interpreters) falls back
    # to help instead of raising IndexError from pop(1).
    command = sys.argv.pop(1) if len(sys.argv) > 1 else Command.HELP
    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.ENV:
        print_env()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
        if force_torchrun or get_device_count() > 1:
            _launch_torchrun()
        else:
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        raise NotImplementedError("Unknown command: {}.".format(command))
| |
|