import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import spaces

# Model configuration
MODEL_PATH = "ibm-granite/granite-4.0-h-1b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables to store model and tokenizer
tokenizer = None
model = None

def load_model():
    """Load the model and tokenizer"""
    global tokenizer, model
    
    if tokenizer is None or model is None:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
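        # device_map=DEVICE places the whole model on the GPU when one is
        # available, otherwise it stays on the CPU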
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=DEVICE)
        model.eval()
        print("Model loaded successfully!")

@spaces.GPU  # Request a GPU for this call (Hugging Face Spaces ZeroGPU)
def chat_with_model(message, history):
    """
    Chat function that processes user input and generates responses
    
    Args:
        message (str): Current user message
        history (list): Previous conversation history
        
    Returns:
        str: Model response
    """
    try:
        # Load model if not already loaded
        load_model()
        
        # Prepare chat format
        messages = []
        
        # Add system message for better performance
        messages.append({
            "role": "system", 
            "content": "You are a helpful AI assistant. Provide clear, accurate, and helpful responses."
        })
        
        # Add conversation history (handles both tuple- and messages-style history)
        for item in history:
            if isinstance(item, dict):  # openai-style {"role", "content"} dicts
                messages.append({"role": item["role"], "content": item["content"]})
            else:  # (user_msg, assistant_msg) tuples
                user_msg, assistant_msg = item
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})
        
        # Add current message
        messages.append({"role": "user", "content": message})
        
        # Apply chat template
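        # add_generation_prompt=True appends the assistant role header so the
        # model continues the conversation as the assistant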
        chat = tokenizer.apply_chat_template(
            messages, 
            tokenize=False, 
            add_generation_prompt=True
        )
        
        # Tokenize input
        input_tokens = tokenizer(chat, return_tensors="pt").to(DEVICE)
        
        # Generate response
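        # Sampling with temperature=0.7 keeps replies varied but on-topic;
        # pad_token_id is set explicitly to avoid the missing-pad-token warning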
        with torch.no_grad():
            output = model.generate(
                **input_tokens,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode response
        full_response = tokenizer.batch_decode(output)[0]
        
        # Extract only the assistant's response. The decoded text still contains
        # the full prompt (including earlier assistant turns), so look for the
        # *last* assistant role marker rather than the first.
        role_marker = '<|start_of_role|>assistant<|end_of_role|>'
        assistant_start = full_response.rfind(role_marker)
        if assistant_start != -1:
            assistant_start += len(role_marker)
            assistant_response = full_response[assistant_start:].strip()
        else:
            # Fallback marker in case the chat template formats roles differently
            fallback_marker = '<|assistant|>'
            response_start = full_response.rfind(fallback_marker)
            if response_start != -1:
                response_start += len(fallback_marker)
                assistant_response = full_response[response_start:].strip()
            else:
                assistant_response = full_response.strip()
        
        # Clean up the response - remove end markers
        assistant_response = assistant_response.replace('<|endoftext|>', '').replace('<|end_of_text|>', '').strip()
        
        return assistant_response
        
    except Exception as e:
        print(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error: {str(e)}. Please try again."

def clear_chat():
    """Clear the chat history"""
    return []

# Create the Gradio chat interface
def create_chat_app():
    with gr.Blocks(title="IBM Granite Chat", css="""
        .header {
            text-align: center;
            padding: 10px;
            background: linear-gradient(90deg, #0066cc, #004499);
            color: white;
            margin-bottom: 20px;
            border-radius: 10px;
        }
        .header a {
            color: #ffffff;
            text-decoration: none;
            font-weight: bold;
        }
        .header a:hover {
            text-decoration: underline;
        }
    """) as demo:
        
        # Header with attribution
        gr.HTML("""
            <div class="header">
                <h1>IBM Granite 4.0 Chat</h1>
                <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
            </div>
        """)
        
        # Chat interface
        chatbot = gr.ChatInterface(
            fn=chat_with_model,
            title="Chat with IBM Granite 4.0",
            description="Chat with the IBM Granite 4.0 1B parameter language model. Ask questions, get help, or have a conversation!",
            examples=[
                "What is machine learning?",
                "Explain quantum computing in simple terms",
                "How can I improve my programming skills?",
                "What are the latest developments in AI?",
                "Tell me about IBM Research"
            ],
        )
        
        # Additional info
        with gr.Accordion("Model Information", open=False):
            gr.Markdown(f"""
            ## Model Details
            - **Model**: {MODEL_PATH}
            - **Parameters**: 1B
            - **Device**: {DEVICE.upper()}
            - **Max Tokens**: 200 per response
            - **Temperature**: 0.7 (for balanced creativity and accuracy)
            
            ## Tips
            - Ask specific questions for better results
            - The model works best with clear, concise prompts
            - Try asking follow-up questions to dive deeper into topics
            - The model can help with programming, explanations, and general knowledge
            """)
    
    return demo

if __name__ == "__main__":
    # Create and launch the app
    app = create_chat_app()
    
    # Launch configuration
    app.launch()