evalstate (HF Staff) committed
Commit b6a400a · verified · 1 parent: 2638a07

Upload train_medium.py with huggingface_hub

Files changed (1):
  train_medium.py  +146  -0
train_medium.py ADDED
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "datasets",  # imported directly below, so declared explicitly
#     "trackio",
# ]
# ///
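
# The block above is PEP 723 inline script metadata: a PEP 723-aware runner such
# as uv can install these dependencies on the fly (uv run train_medium.py), while
# plain `python train_medium.py` assumes the packages are already installed.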

import os

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio

print("🚀 Medium-Scale SFT Training with Trackio")
print("=" * 60)

# Initialize Trackio with Space sync
print("\n📊 Initializing Trackio...")
trackio.init(
    project="medium-sft-training",
    space_id="evalstate/trl-trackio-dashboard",
    config={
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "dataset_size": 1000,
        "num_epochs": 3,
        "learning_rate": 2e-5,
        "batch_size": 4,
        "gradient_accumulation": 4,
        "lora_r": 16,
        "lora_alpha": 32,
        "hardware": "a10g-large",
    },
)
print("✅ Trackio initialized!")
print("📈 Dashboard: https://huggingface.co/spaces/evalstate/trl-trackio-dashboard")

# Load dataset - 1000 examples
print("\n📊 Loading dataset...")
dataset = load_dataset("trl-lib/Capybara", split="train[:1000]")
print(f"✅ Dataset loaded: {len(dataset)} examples")

# Get username for Hub repo ids
username = os.environ.get("HF_USERNAME", "evalstate")

# Training configuration - production settings
print("\n⚙️ Configuring training...")
config = SFTConfig(
    # Output and Hub settings
    output_dir="qwen-capybara-medium",
    push_to_hub=True,
    hub_model_id=f"{username}/qwen-capybara-medium",
    hub_strategy="every_save",  # Push all checkpoints

    # Training parameters - 3 epochs on 1K examples
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size = 16

    # Learning rate and schedule
    learning_rate=2e-5,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",

    # Logging and checkpointing
    logging_steps=10,  # Log every 10 steps
    save_strategy="steps",
    save_steps=50,  # Save every 50 steps
    save_total_limit=3,  # Keep only the 3 latest checkpoints

    # Evaluation
    eval_strategy="steps",
    eval_steps=50,

    # Optimization
    bf16=True,  # Use bfloat16 on A10G
    gradient_checkpointing=True,  # Trade compute for memory

    # Trackio monitoring
    report_to="trackio",
)
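# Note: with hub_strategy="every_save", each 50-step checkpoint is uploaded as it
# is written, so intermediate state survives an interrupted job; locally, the
# standard Trainer resume path (trainer.train(resume_from_checkpoint=True)) picks
# up the newest checkpoint left in output_dir.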

# LoRA configuration - larger than the demo run
print("🔧 Setting up LoRA (r=16)...")
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # all attention projections
)
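# If the adapter needs more capacity, Qwen2-style blocks also expose the MLP
# projections (gate_proj, up_proj, down_proj) as LoRA targets, at the cost of
# extra trainable parameters.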

# Create eval split
print("\n🔀 Creating train/eval split...")
dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = dataset_split["train"]
eval_dataset = dataset_split["test"]
print(f"   Train: {len(train_dataset)} examples")
print(f"   Eval: {len(eval_dataset)} examples")

# Initialize trainer (passing a model id string lets SFTTrainer load the model itself)
print("\n🎯 Initializing trainer...")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=config,
    peft_config=peft_config,
)

# Calculate training info: steps = (samples // effective batch) * epochs
effective_batch = config.per_device_train_batch_size * config.gradient_accumulation_steps
total_steps = int(len(train_dataset) // effective_batch * config.num_train_epochs)
print("\n📊 Training Info:")
print(f"   Total steps: ~{total_steps}")
print(f"   Epochs: {config.num_train_epochs}")
print(f"   Effective batch size: {effective_batch}")
print("   Expected time: ~45-60 minutes")
print("   Checkpoints saved every 50 steps")

# Train!
print("\n🏃 Starting training...")
print("📈 Watch live metrics: https://huggingface.co/spaces/evalstate/trl-trackio-dashboard")
print("-" * 60)
trainer.train()

# Save to Hub
print("\n💾 Pushing final model to Hub...")
trainer.push_to_hub()

# Finish Trackio
print("\n📊 Finalizing Trackio metrics...")
trackio.finish()

print("\n" + "=" * 60)
print("✅ Training complete!")
print(f"📦 Model: https://huggingface.co/{username}/qwen-capybara-medium")
print("📊 Metrics: https://huggingface.co/spaces/evalstate/trl-trackio-dashboard")
print("💡 Try the model with:")
print('   from transformers import pipeline')
print(f'   generator = pipeline("text-generation", model="{username}/qwen-capybara-medium")')
print("=" * 60)
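
Note: push_to_hub=True and the Trackio Space sync both write to the Hub, so the
run assumes a valid Hugging Face token is available before launch (for example
via the HF_TOKEN environment variable or huggingface-cli login).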