| # Use CUDA-enabled PyTorch base image | |
| FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime | |
| # Set working directory | |
| WORKDIR /app | |
| # Install system dependencies | |
| RUN apt-get update && apt-get install -y \ | |
| git \ | |
| && rm -rf /var/lib/apt/lists/* | |
| # Copy requirements file | |
| COPY requirements.txt . | |
| # Install Python dependencies | |
| RUN pip install --no-cache-dir -r requirements.txt | |
| # Copy project files | |
| COPY . . | |
| # Create directories for data and models | |
| RUN mkdir -p dataset/final_balanced weights | |
| # Set environment variables | |
| ENV PYTHONPATH=/app | |
| ENV WANDB_API_KEY="" | |
| # Default command to run training | |
| CMD ["python", "model/train.py"] |