Bmccloud22 committed
Commit 90a59c9 · verified · 1 Parent(s): 360e349

Deploy LaunchLLM - Production AI Training Platform

Files changed (4):
  1. .gitignore +1 -0
  2. README.md +2 -2
  3. runpod_client.py +262 -0
  4. runpod_manager.py +483 -0
.gitignore CHANGED
@@ -1,5 +1,6 @@
 __pycache__/
 *.py[cod]
+*.pyc
 *.log
 .secrets/
 .gradio/
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🚀
 colorFrom: blue
 colorTo: purple
 sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 4.0.0
 app_file: app.py
 pinned: false
 license: apache-2.0
@@ -174,4 +174,4 @@ Start by clicking the **Environment** tab above and adding your HuggingFace toke
 
 ---
 
-**Built with ❤️ for domain experts who want custom AI without the complexity**
+**Built with ❤️ for domain experts who want custom AI without the complexity**
runpod_client.py ADDED
@@ -0,0 +1,262 @@
+"""
+RunPod Client - Low-level GraphQL API client for RunPod
+
+Provides direct access to RunPod's GraphQL API for pod management.
+"""
+
+import os
+import requests
+from typing import Optional, List, Dict
+from dataclasses import dataclass
+
+
+@dataclass
+class PodInfo:
+    """Information about a RunPod pod."""
+    id: str
+    name: str
+    status: str
+    gpu_type: str
+    gpu_count: int
+    cost_per_hour: float
+    runtime: Optional[Dict] = None
+
+
+class RunPodClient:
+    """Low-level client for the RunPod GraphQL API."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        self.api_key = api_key or os.getenv("RUNPOD_API_KEY")
+        if not self.api_key:
+            raise ValueError("RunPod API key required. Set the RUNPOD_API_KEY environment variable.")
+
+        self.endpoint = "https://api.runpod.io/graphql"
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.api_key}"
+        }
+
+    def _query(self, query: str, variables: Optional[Dict] = None) -> Dict:
+        """Execute a GraphQL query and return the decoded JSON response."""
+        payload = {
+            "query": query,
+            "variables": variables or {}
+        }
+
+        response = requests.post(
+            self.endpoint,
+            json=payload,
+            headers=self.headers,
+            timeout=30
+        )
+
+        if response.status_code != 200:
+            raise RuntimeError(f"GraphQL request failed: {response.status_code} {response.text}")
+
+        return response.json()
+
+    def list_pods(self) -> List[PodInfo]:
+        """List all pods owned by the authenticated account."""
+        query = """
+        query {
+            myself {
+                pods {
+                    id
+                    name
+                    desiredStatus
+                    runtime {
+                        gpus {
+                            id
+                        }
+                    }
+                    machine {
+                        podHostId
+                    }
+                    costPerHr
+                    gpuCount
+                }
+            }
+        }
+        """
+
+        result = self._query(query)
+
+        if "errors" in result:
+            print(f"Error listing pods: {result['errors']}")
+            return []
+
+        pods_data = result.get("data", {}).get("myself", {}).get("pods", [])
+        pods = []
+
+        for pod_data in pods_data:
+            gpu_type = "GPU"  # Generic fallback; this query does not return the GPU type name
+            if pod_data.get("runtime") and pod_data["runtime"].get("gpus"):
+                gpu_id = pod_data["runtime"]["gpus"][0].get("id", "")
+                if gpu_id:
+                    gpu_type = f"GPU-{gpu_id[:8]}"  # Use a shortened GPU ID instead
+
+            pods.append(PodInfo(
+                id=pod_data["id"],
+                name=pod_data["name"],
+                status=pod_data.get("desiredStatus", "unknown"),
+                gpu_type=gpu_type,
+                gpu_count=pod_data.get("gpuCount", 0),
+                cost_per_hour=pod_data.get("costPerHr", 0.0),
+                runtime=pod_data.get("runtime")
+            ))
+
+        return pods
+
+    def create_pod(
+        self,
+        name: str,
+        image_name: str,
+        gpu_type_id: str,
+        gpu_count: int = 1,
+        volume_in_gb: int = 100,
+        container_disk_in_gb: int = 50,
+        ports: str = "8888/http"
+    ) -> Optional[str]:
+        """Create a new on-demand pod and return its ID, or None on failure."""
+        query = """
+        mutation($input: PodFindAndDeployOnDemandInput!) {
+            podFindAndDeployOnDemand(input: $input) {
+                id
+                name
+                desiredStatus
+            }
+        }
+        """
+
+        variables = {
+            "input": {
+                "name": name,
+                "imageName": image_name,
+                "gpuTypeId": gpu_type_id,
+                "gpuCount": gpu_count,
+                "volumeInGb": volume_in_gb,
+                "containerDiskInGb": container_disk_in_gb,
+                "ports": ports,
+                "cloudType": "ALL"
+            }
+        }
+
+        result = self._query(query, variables)
+
+        if "errors" in result:
+            print(f"Error creating pod: {result['errors']}")
+            return None
+
+        pod_data = result.get("data", {}).get("podFindAndDeployOnDemand")
+        if pod_data:
+            return pod_data["id"]
+
+        return None
+
+    def stop_pod(self, pod_id: str) -> bool:
+        """Stop a running pod."""
+        query = """
+        mutation($input: PodStopInput!) {
+            podStop(input: $input) {
+                id
+                desiredStatus
+            }
+        }
+        """
+
+        variables = {
+            "input": {
+                "podId": pod_id
+            }
+        }
+
+        result = self._query(query, variables)
+
+        if "errors" in result:
+            print(f"Error stopping pod: {result['errors']}")
+            return False
+
+        return True
+
+    def terminate_pod(self, pod_id: str) -> bool:
+        """Terminate (permanently delete) a pod."""
+        query = """
+        mutation($input: PodTerminateInput!) {
+            podTerminate(input: $input)
+        }
+        """
+
+        variables = {
+            "input": {
+                "podId": pod_id
+            }
+        }
+
+        result = self._query(query, variables)
+
+        if "errors" in result:
+            print(f"Error terminating pod: {result['errors']}")
+            return False
+
+        return True
+
+    def get_gpu_types(self) -> List[Dict]:
+        """Get available GPU types."""
+        query = """
+        query {
+            gpuTypes {
+                id
+                displayName
+                memoryInGb
+                secureCloud
+                communityCloud
+            }
+        }
+        """
+
+        result = self._query(query)
+
+        if "errors" in result:
+            print(f"Error getting GPU types: {result['errors']}")
+            return []
+
+        return result.get("data", {}).get("gpuTypes", [])
+
+    def get_pod_details(self, pod_id: str) -> Optional[Dict]:
+        """Get detailed information about a specific pod, including exposed ports."""
+        query = """
+        query($podId: String!) {
+            pod(input: {podId: $podId}) {
+                id
+                name
+                desiredStatus
+                runtime {
+                    gpus {
+                        id
+                    }
+                    ports {
+                        ip
+                        isIpPublic
+                        privatePort
+                        publicPort
+                        type
+                    }
+                }
+                machine {
+                    podHostId
+                }
+                gpuCount
+                costPerHr
+            }
+        }
+        """
+
+        variables = {"podId": pod_id}
+        result = self._query(query, variables)
+
+        if "errors" in result:
+            print(f"Error getting pod details: {result['errors']}")
+            return None
+
+        return result.get("data", {}).get("pod")
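
For reference, a minimal usage sketch for RunPodClient, assuming RUNPOD_API_KEY is set in the environment; the pod name and gpu_type_id below are illustrative placeholders (query get_gpu_types() for the IDs valid on your account), and the image is the same PyTorch image used elsewhere in this commit:

from runpod_client import RunPodClient

client = RunPodClient()  # reads RUNPOD_API_KEY from the environment

# The `id` field returned here is what create_pod expects as gpu_type_id.
for gpu in client.get_gpu_types():
    print(gpu["id"], gpu.get("memoryInGb"))

pod_id = client.create_pod(
    name="smoke-test-pod",  # placeholder name
    image_name="runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
    gpu_type_id="NVIDIA GeForce RTX 4090",  # placeholder; pick from get_gpu_types()
)

if pod_id:
    print(client.get_pod_details(pod_id))
    client.terminate_pod(pod_id)  # terminate promptly to avoid charges
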
runpod_manager.py ADDED
@@ -0,0 +1,483 @@
+"""
+RunPod Manager - High-level management for RunPod instances
+
+Provides higher-level functions for managing RunPod instances, including
+deployment, monitoring, and SSH access.
+"""
+
+import json
+import os
+import tempfile
+import time
+from typing import Optional, Dict, List, Tuple
+from dataclasses import dataclass
+
+import paramiko
+
+from runpod_client import RunPodClient, PodInfo
+
+
+@dataclass
+class DeploymentConfig:
+    """Configuration for a RunPod deployment."""
+    name: str = "aura-training-pod"
+    gpu_type: str = "NVIDIA A100 80GB PCIe"
+    gpu_count: int = 1
+    storage_gb: int = 100
+    image: str = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
+    ports: str = "8888/http,22/tcp,7860/http"  # Jupyter, SSH, Gradio
+
+
+@dataclass
+class TrainingConfig:
+    """Configuration for model training on RunPod."""
+    model_name: str = "Qwen/Qwen2.5-7B-Instruct"
+    lora_rank: int = 8
+    learning_rate: float = 2e-4
+    num_epochs: int = 3
+    batch_size: int = 4
+    gradient_accumulation_steps: int = 4
+    use_4bit: bool = True
+    max_length: int = 2048
+
+
+class RunPodManager:
+    """Manager for RunPod instances with deployment and monitoring."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        self.client = RunPodClient(api_key)
+
+    def deploy_training_pod(
+        self,
+        name: str,
+        gpu_type: str = "NVIDIA A100 80GB PCIe",
+        gpu_count: int = 1,
+        storage_gb: int = 100
+    ) -> Optional[str]:
+        """Deploy a pod configured for model training."""
+
+        # Use a PyTorch image with CUDA support
+        image = "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04"
+
+        print(f"Deploying training pod '{name}'...")
+        print(f"  GPU: {gpu_type} x{gpu_count}")
+        print(f"  Storage: {storage_gb}GB")
+
+        pod_id = self.client.create_pod(
+            name=name,
+            image_name=image,
+            gpu_type_id=gpu_type,
+            gpu_count=gpu_count,
+            volume_in_gb=storage_gb,
+            container_disk_in_gb=50,
+            ports="8888/http,22/tcp,7860/http"  # Jupyter, SSH, Gradio
+        )
+
+        if pod_id:
+            print(f"Pod created: {pod_id}")
+            print("Waiting for pod to start...")
+            time.sleep(10)  # Give it time to start
+
+        return pod_id
+
+    def get_pod_status(self, pod_id: str) -> Optional[Dict]:
+        """Get the current status of a pod."""
+        pods = self.client.list_pods()
+
+        for pod in pods:
+            if pod.id == pod_id:
+                return {
+                    "id": pod.id,
+                    "name": pod.name,
+                    "status": pod.status,
+                    "gpu_type": pod.gpu_type,
+                    "cost_per_hour": pod.cost_per_hour
+                }
+
+        return None
+
+    def list_all_pods(self) -> List[PodInfo]:
+        """List all pods."""
+        return self.client.list_pods()
+
+    def stop_pod(self, pod_id: str) -> bool:
+        """Stop a running pod."""
+        print(f"Stopping pod {pod_id}...")
+        return self.client.stop_pod(pod_id)
+
+    def terminate_pod(self, pod_id: str) -> bool:
+        """Terminate a pod."""
+        print(f"Terminating pod {pod_id}...")
+        return self.client.terminate_pod(pod_id)
+
+    def get_ssh_connection(
+        self,
+        pod_ip: str,
+        port: int = 22,
+        username: str = "root",
+        key_file: Optional[str] = None,
+        password: Optional[str] = None
+    ) -> Optional[paramiko.SSHClient]:
+        """Get an SSH connection to a pod. RunPod exposes SSH on a public
+        port that usually differs from 22, so pass the mapped port."""
+
+        ssh = paramiko.SSHClient()
+        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+
+        try:
+            if key_file:
+                ssh.connect(
+                    pod_ip,
+                    port=port,
+                    username=username,
+                    key_filename=key_file,
+                    timeout=10
+                )
+            elif password:
+                ssh.connect(
+                    pod_ip,
+                    port=port,
+                    username=username,
+                    password=password,
+                    timeout=10
+                )
+            else:
+                print("Either key_file or password must be provided")
+                return None
+
+            return ssh
+
+        except Exception as e:
+            print(f"SSH connection failed: {e}")
+            return None
+
+    def execute_command(
+        self,
+        ssh: paramiko.SSHClient,
+        command: str
+    ) -> Tuple[str, str]:
+        """Execute a command via SSH and return (stdout, stderr)."""
+        stdin, stdout, stderr = ssh.exec_command(command)
+        return stdout.read().decode(), stderr.read().decode()
+
+    def upload_file(
+        self,
+        ssh: paramiko.SSHClient,
+        local_path: str,
+        remote_path: str
+    ) -> bool:
+        """Upload a file to the pod over SFTP."""
+        try:
+            sftp = ssh.open_sftp()
+            sftp.put(local_path, remote_path)
+            sftp.close()
+            return True
+        except Exception as e:
+            print(f"File upload failed: {e}")
+            return False
+
+    def download_file(
+        self,
+        ssh: paramiko.SSHClient,
+        remote_path: str,
+        local_path: str
+    ) -> bool:
+        """Download a file from the pod over SFTP."""
+        try:
+            sftp = ssh.open_sftp()
+            sftp.get(remote_path, local_path)
+            sftp.close()
+            return True
+        except Exception as e:
+            print(f"File download failed: {e}")
+            return False
+
+    def setup_training_environment(
+        self,
+        ssh: paramiko.SSHClient,
+        requirements_file: Optional[str] = None
+    ) -> bool:
+        """Set up the training environment on a pod."""
+
+        print("Setting up training environment...")
+
+        # Update pip
+        print("Updating pip...")
+        stdout, stderr = self.execute_command(ssh, "pip install --upgrade pip")
+
+        if requirements_file:
+            # Upload the requirements file
+            print("Uploading requirements...")
+            if not self.upload_file(ssh, requirements_file, "/tmp/requirements.txt"):
+                return False
+
+            # Install requirements
+            print("Installing requirements...")
+            stdout, stderr = self.execute_command(
+                ssh,
+                "pip install -r /tmp/requirements.txt"
+            )
+
+            if stderr and "error" in stderr.lower():
+                print(f"Installation errors: {stderr}")
+                return False
+
+        print("Environment setup complete!")
+        return True
+
+    def monitor_training(
+        self,
+        ssh: paramiko.SSHClient,
+        log_file: str = "/workspace/training.log",
+        interval: int = 30
+    ):
+        """Monitor training progress by polling the remote log file."""
+
+        print(f"Monitoring training log: {log_file}")
+        print(f"Checking every {interval} seconds...")
+        print("Press Ctrl+C to stop monitoring\n")
+
+        last_line_count = 0
+
+        try:
+            while True:
+                # Fetch the current log file content
+                stdout, stderr = self.execute_command(
+                    ssh,
+                    f"cat {log_file} 2>/dev/null || echo 'Log file not found'"
+                )
+
+                lines = stdout.strip().split('\n')
+                new_lines = lines[last_line_count:]
+
+                if new_lines and new_lines[0] != 'Log file not found':
+                    for line in new_lines:
+                        print(line)
+                    last_line_count = len(lines)
+
+                time.sleep(interval)
+
+        except KeyboardInterrupt:
+            print("\nStopped monitoring")
+
+    def get_available_gpus(self) -> List[Dict]:
+        """Get the list of available GPU types."""
+        return self.client.get_gpu_types()
+
+    def estimate_cost(
+        self,
+        gpu_type: str,
+        gpu_count: int,
+        hours: float
+    ) -> Optional[float]:
+        """Estimate the cost of a training job from the hourly rate of an
+        existing pod with a matching GPU configuration. Returns None if no
+        such pod is currently deployed."""
+
+        pods = self.client.list_pods()
+
+        # Find the cost per hour for this GPU configuration
+        for pod in pods:
+            if pod.gpu_type == gpu_type and pod.gpu_count == gpu_count:
+                return pod.cost_per_hour * hours
+
+        return None
+
+    def run_training_on_pod(
+        self,
+        pod_id: str,
+        training_data: List[Dict],
+        model_name: str,
+        lora_config: Dict,
+        training_config: Dict
+    ) -> bool:
+        """Run training on a RunPod pod instead of locally."""
+
+        print(f"Starting remote training on pod {pod_id}...")
+
+        # 1. Get pod details to find SSH connection info
+        pod_details = self.client.get_pod_details(pod_id)
+        if not pod_details:
+            print("Error: Could not get pod details")
+            return False
+
+        runtime = pod_details.get("runtime")
+        if not runtime or not runtime.get("ports"):
+            print("Error: Pod runtime not available. Pod may still be starting.")
+            return False
+
+        # Find the public port mapped to the pod's SSH port (22)
+        ssh_port = None
+        ssh_ip = None
+        for port in runtime["ports"]:
+            if port.get("privatePort") == 22:
+                ssh_ip = port.get("ip")
+                ssh_port = port.get("publicPort")
+                break
+
+        if not ssh_ip or not ssh_port:
+            print("Error: SSH port not found in pod details")
+            return False
+
+        print(f"SSH Connection: {ssh_ip}:{ssh_port}")
+
+        # 2. Save the training data to a temp file
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+            json.dump(training_data, f)
+            data_file = f.name
+
+        # 3. Generate the training script that will run on the pod
+        training_script = f"""
+import json
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    DataCollatorForLanguageModeling,
+    TrainingArguments,
+    Trainer,
+)
+from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
+from datasets import Dataset
+import torch
+
+print("Loading training data...")
+with open('/workspace/training_data.json', 'r') as f:
+    data = json.load(f)
+
+print(f"Loaded {{len(data)}} training examples")
+
+print("Loading model: {model_name}")
+# 4-bit quantized load; requires bitsandbytes
+model = AutoModelForCausalLM.from_pretrained(
+    "{model_name}",
+    load_in_4bit=True,
+    device_map="auto",
+    torch_dtype=torch.float16
+)
+tokenizer = AutoTokenizer.from_pretrained("{model_name}")
+tokenizer.pad_token = tokenizer.eos_token
+
+print("Preparing model for training...")
+model = prepare_model_for_kbit_training(model)
+
+lora_config = LoraConfig(
+    r={lora_config.get('r', 16)},
+    lora_alpha={lora_config.get('lora_alpha', 32)},
+    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+    lora_dropout=0.05,
+    bias="none",
+    task_type=TaskType.CAUSAL_LM
+)
+
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()
+
+print("Preparing dataset...")
+def format_data(example):
+    text = f"###Instruction: {{example['instruction']}}\\n###Response: {{example['output']}}"
+    return tokenizer(text, truncation=True, max_length=2048, padding="max_length")
+
+dataset = Dataset.from_list(data)
+dataset = dataset.map(format_data, batched=False)
+
+training_args = TrainingArguments(
+    output_dir="/workspace/outputs",
+    num_train_epochs={training_config.get('num_epochs', 3)},
+    per_device_train_batch_size={training_config.get('batch_size', 1)},
+    gradient_accumulation_steps={training_config.get('gradient_accumulation_steps', 16)},
+    learning_rate={training_config.get('learning_rate', 2e-4)},
+    logging_steps=10,
+    save_steps=100,
+    save_total_limit=2,
+    fp16=True,
+    report_to="none"
+)
+
+print("Starting training...")
+trainer = Trainer(
+    model=model,
+    args=training_args,
+    train_dataset=dataset,
+    # For causal LM the collator copies input_ids into labels (mlm=False),
+    # which the Trainer needs in order to compute a loss
+    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
+)
+
+trainer.train()
+
+print("Saving model...")
+model.save_pretrained("/workspace/final_model")
+tokenizer.save_pretrained("/workspace/final_model")
+
+print("Training complete!")
+"""
+
+        # Save the script to a temp file
+        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
+            f.write(training_script)
+            script_file = f.name
+
+        print("Connecting to pod via SSH...")
+
+        # Locate the SSH key
+        key_path = os.path.join(os.getcwd(), ".ssh", "runpod_key")
+
+        if not os.path.exists(key_path):
+            print(f"Error: SSH key not found at {key_path}")
+            print("Run: ssh-keygen -t ed25519 -f .ssh/runpod_key -N ''")
+            print("Then add the public key to RunPod: https://www.runpod.io/console/user/settings")
+            return False
+
+        # Open the SSH connection (RunPod pods use the root user by default)
+        ssh = self.get_ssh_connection(
+            pod_ip=ssh_ip,
+            port=ssh_port,
+            username="root",
+            key_file=key_path
+        )
+
+        if not ssh:
+            print("Error: Could not establish SSH connection")
+            print(f"Tried using key: {key_path}")
+            print("Verify the public key is added to RunPod: https://www.runpod.io/console/user/settings")
+            return False
+
+        try:
+            # Upload the training data
+            print("Uploading training data...")
+            if not self.upload_file(ssh, data_file, "/workspace/training_data.json"):
+                return False
+
+            # Upload the training script
+            print("Uploading training script...")
+            if not self.upload_file(ssh, script_file, "/workspace/train.py"):
+                return False
+
+            # Install required packages
+            print("Installing required packages...")
+            stdout, stderr = self.execute_command(
+                ssh,
+                "pip install transformers peft datasets accelerate bitsandbytes"
+            )
+
+            # Launch training in the background with nohup so it survives
+            # the SSH session
+            print("Starting training on pod...")
+            print("Training will run in the background on the pod.")
+            stdout, stderr = self.execute_command(
+                ssh,
+                "nohup python /workspace/train.py > /workspace/training.log 2>&1 &"
+            )
+
+            print("\nTraining initiated successfully!")
+            print("Training data uploaded to: /workspace/training_data.json")
+            print("Training script uploaded to: /workspace/train.py")
+            print("Training log available at: /workspace/training.log")
+            print("\nTo monitor progress, you can:")
+            print(f"  1. SSH to pod: ssh root@{ssh_ip} -p {ssh_port}")
+            print("  2. View logs: tail -f /workspace/training.log")
+
+            return True
+
+        except Exception as e:
+            print(f"Error during remote training setup: {e}")
+            return False
+        finally:
+            ssh.close()
+            # Clean up local temp files
+            try:
+                os.unlink(data_file)
+                os.unlink(script_file)
+            except OSError:
+                pass
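
And a minimal end-to-end sketch for RunPodManager, assuming RUNPOD_API_KEY is set and an SSH key pair exists at .ssh/runpod_key with the public key registered in the RunPod console; the training data below is a toy placeholder, and a real pod may need longer than the 10-second startup wait before its SSH port appears in the pod details:

from runpod_manager import RunPodManager

manager = RunPodManager()  # reads RUNPOD_API_KEY from the environment

pod_id = manager.deploy_training_pod(name="aura-training-pod")
if pod_id:
    ok = manager.run_training_on_pod(
        pod_id=pod_id,
        training_data=[{"instruction": "Say hi", "output": "Hi!"}],  # toy data
        model_name="Qwen/Qwen2.5-7B-Instruct",
        lora_config={"r": 16, "lora_alpha": 32},
        training_config={"num_epochs": 1, "batch_size": 1},
    )
    print("Training started" if ok else "Training failed to start")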