import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import spaces
import os

# Model configuration
MODEL_PATH = "ibm-granite/granite-4.0-h-1b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Global variables to store model and tokenizer
tokenizer = None
model = None


def load_model():
    """Load the model and tokenizer"""
    global tokenizer, model
    if tokenizer is None or model is None:
        print("Loading model and tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=DEVICE)
        model.eval()
        print("Model loaded successfully!")


@spaces.GPU  # Use GPU for inference
def chat_with_model(message, history):
    """
    Chat function that processes user input and generates responses

    Args:
        message (str): Current user message
        history (list): Previous conversation history as (user, assistant) tuples

    Returns:
        str: Model response
    """
    try:
        # Load model if not already loaded
        load_model()

        # Prepare chat format
        messages = []

        # Add system message for better performance
        messages.append({
            "role": "system",
            "content": "You are a helpful AI assistant. Provide clear, accurate, and helpful responses."
        })

        # Add conversation history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})

        # Add current message
        messages.append({"role": "user", "content": message})

        # Apply chat template
        chat = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize input
        input_tokens = tokenizer(chat, return_tensors="pt").to(DEVICE)

        # Generate response
        with torch.no_grad():
            output = model.generate(
                **input_tokens,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )

        # Decode response (prompt + generation)
        full_response = tokenizer.batch_decode(output)[0]

        # Extract only the assistant's reply: use rfind so earlier assistant turns
        # from the conversation history are not included in the extracted text
        assistant_start = full_response.rfind('<|start_of_role|>assistant<|end_of_role|>')
        if assistant_start != -1:
            assistant_start += len('<|start_of_role|>assistant<|end_of_role|>')
            assistant_response = full_response[assistant_start:].strip()
        else:
            # Fallback to the generic marker if the Granite pattern is not found
            response_start = full_response.rfind('<|assistant|>')
            if response_start != -1:
                response_start += len('<|assistant|>')
                assistant_response = full_response[response_start:].strip()
            else:
                assistant_response = full_response.strip()

        # Clean up the response - remove end markers
        assistant_response = assistant_response.replace('<|endoftext|>', '').replace('<|end_of_text|>', '').strip()

        return assistant_response

    except Exception as e:
        print(f"Error generating response: {e}")
        return f"I apologize, but I encountered an error: {str(e)}. Please try again."


def clear_chat():
    """Clear the chat history"""
    return []


# Create the Gradio chat interface
def create_chat_app():
    with gr.Blocks(title="IBM Granite Chat", css="""
        .header {
            text-align: center;
            padding: 10px;
            background: linear-gradient(90deg, #0066cc, #004499);
            color: white;
            margin-bottom: 20px;
            border-radius: 10px;
        }
        .header a {
            color: #ffffff;
            text-decoration: none;
            font-weight: bold;
        }
        .header a:hover {
            text-decoration: underline;
        }
    """) as demo:
        # Header with attribution
        gr.HTML("""
            <div class="header">
                Powered by IBM Granite &middot; Built with anycoder
            </div>
        """)
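        # --- Minimal chat UI wiring below the header ---
        # The original file is truncated after the header HTML, so this is an
        # illustrative completion: the component names (chatbot, msg, send_btn,
        # clear_btn) and the respond() wrapper are assumptions, not taken from
        # the original source. They reuse chat_with_model() and clear_chat()
        # defined above with the tuple-style history those functions expect.
        chatbot = gr.Chatbot(height=450)
        msg = gr.Textbox(placeholder="Type your message...", show_label=False)

        with gr.Row():
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear")

        def respond(message, chat_history):
            # Generate a reply, append the new (user, assistant) turn, and
            # clear the input textbox
            reply = chat_with_model(message, chat_history)
            chat_history = chat_history + [(message, reply)]
            return "", chat_history

        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        send_btn.click(respond, [msg, chatbot], [msg, chatbot])
        clear_btn.click(clear_chat, None, chatbot)

    return demo


# Entry point: Hugging Face Spaces executes app.py directly
if __name__ == "__main__":
    demo = create_chat_app()
    demo.launch()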