import os
import time
from typing import Optional

from dotenv import load_dotenv
from huggingface_hub import InferenceClient

# Load environment variables (e.g. HF_TOKEN) from a local .env file at import time.
load_dotenv()
| |
|
class HuggingFaceInferenceAPI:
    """
    Manages interactions with Hugging Face Inference API using the official InferenceClient.

    Attributes:
        api_token: Hugging Face API token used for authentication.
        client: Configured ``InferenceClient`` instance (provider auto-selected).
        model: Default chat model used for generation and query rewriting.
    """

    def __init__(self, api_token: Optional[str] = None):
        """
        Initialize the API client.

        Args:
            api_token: Hugging Face API token. Falls back to the ``HF_TOKEN``
                environment variable when not provided.

        Raises:
            ValueError: If no token is supplied and ``HF_TOKEN`` is unset.
        """
        self.api_token = api_token or os.environ.get("HF_TOKEN")
        if not self.api_token:
            raise ValueError(
                "HF_TOKEN not found. Please set the HF_TOKEN environment variable or pass it as an argument."
            )

        self.client = InferenceClient(
            provider="auto",
            api_key=self.api_token,
        )
        # Default model for response generation and query rewriting.
        self.model = "meta-llama/Llama-3.2-3B-Instruct"

    def _generate_text(
        self,
        prompt: str,
        max_tokens: int = 200,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate text using the InferenceClient.

        Args:
            prompt: The user-turn content.
            max_tokens: Maximum tokens to generate.
            system_prompt: Optional system-turn content. Sending it as a real
                "system" message (rather than inlining template tags into the
                user turn) lets the provider apply the model's chat template.

        Returns:
            The generated text, or an error-description string on failure.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
                stream=False,
            )
            # Some providers may return None content; normalize to "".
            return response.choices[0].message.content or ""
        except Exception as e:
            # Fail soft: callers surface this string instead of crashing the app.
            print(f"Error: {e}")
            return f"Error generating response: {e}"

    def moderate_query(self, query: str) -> bool:
        """
        Moderates a query using a stable, high-availability model (Qwen 2.5).

        Args:
            query: The user's query.

        Returns:
            True if the query is classified SAFE (or moderation errors out),
            False if it is classified UNSAFE.
        """
        moderator_model = "Qwen/Qwen2.5-7B-Instruct"

        # BUG FIX: the previous version embedded raw ChatML tags
        # (<|im_start|>system ... <|im_start|>assistant) inside a single
        # "user" message sent through the chat-completions API, which applies
        # the model's chat template a second time. Use proper roles instead.
        system_instructions = (
            "You are a content moderator. Your job is to classify if a user query is SAFE or UNSAFE.\n"
            "- SAFE: General questions, product inquiries, electronics, store help, or friendly chat.\n"
            "- UNSAFE: Hate speech, violence, illegal acts, or sexual content.\n"
            "Respond with ONLY the word 'SAFE' or 'UNSAFE'."
        )

        try:
            print(f"Sending moderation request to {moderator_model}...")
            response = self.client.chat.completions.create(
                model=moderator_model,
                messages=[
                    {"role": "system", "content": system_instructions},
                    {"role": "user", "content": query},
                ],
                max_tokens=5,
            )

            result = (response.choices[0].message.content or "").strip().upper()
            print(f"Moderation result: {result}")

            return "UNSAFE" not in result

        except Exception as e:
            print(f"Moderation API Error: {repr(e)}")
            # Fail open so a moderation outage does not block every query.
            return True

    def generate_response(self, query: str, system_prompt: str) -> str:
        """
        Generates a response using the default chat model via the
        Hugging Face Inference API.

        Args:
            query: The user's query.
            system_prompt: The system prompt with context and instructions.

        Returns:
            The generated response.
        """
        try:
            # BUG FIX: previously the system prompt and query were wrapped in
            # Mistral-style "<s>[INST] ... [/INST]" tags and sent as a single
            # user message, even though the configured model is Llama 3.2.
            # Role-separated messages let the provider apply the right template.
            response = self._generate_text(
                query, max_tokens=500, system_prompt=system_prompt
            )
            return response.strip()
        except Exception as e:
            print(f"Error during response generation: {e}")
            return "I'm sorry, but I encountered an error while trying to generate a response."

    def rewrite_query(self, query: str, system_prompt: str) -> str:
        """
        Rewrites a query using the default chat model via the
        Hugging Face Inference API.

        Args:
            query: The user's query.
            system_prompt: The system prompt with instructions.

        Returns:
            The rewritten query (falls back to the original query on error).
        """
        try:
            # Same fix as generate_response: use real chat roles instead of
            # hand-rolled "[INST]" prompt-template tags.
            response = self._generate_text(
                f"User query: '{query}'", max_tokens=200, system_prompt=system_prompt
            )
            rewritten = response.strip()

            # Strip one pair of matching surrounding quotes, if present.
            # The len > 1 guard prevents a lone quote char from being emptied.
            if (
                len(rewritten) > 1
                and rewritten[0] == rewritten[-1]
                and rewritten[0] in ('"', "'")
            ):
                rewritten = rewritten[1:-1]

            return rewritten
        except Exception as e:
            print(f"Error during query rewrite: {e}")
            return query
| |
|
| |
|
| | |
# Module-level singleton, created lazily on first use.
_api_client = None


def get_api_client() -> HuggingFaceInferenceAPI:
    """Get or initialize the Hugging Face Inference API client."""
    global _api_client
    if _api_client is not None:
        return _api_client
    _api_client = HuggingFaceInferenceAPI()
    return _api_client
| |
|
| |
|
def moderate_query(query: str) -> bool:
    """
    Moderates a query using Qwen via Hugging Face Inference API.

    Args:
        query: The user's query.

    Returns:
        True if the query is safe, False otherwise.
    """
    print("Moderating query...")
    # Delegate to the shared client's moderation endpoint.
    return get_api_client().moderate_query(query)
| |
|
def generate_response(query: str, retrieved_docs: list, history: list) -> str:
    """
    Generates a response using Llama-3.2-3B-Instruct via Hugging Face Inference API,
    ensuring it adheres to the retrieved documents.

    Args:
        query: The user's query.
        retrieved_docs: A list of document contents (strings).
        history: The chat history from Gradio (list of {"role", "content"} dicts).

    Returns:
        The generated response.
    """
    # BUG FIX: rules 7-9 previously ended with trailing backslashes inside the
    # triple-quoted string, which glued consecutive numbered rules onto one
    # line ("...the query.8. Make sure..."); each rule now has its own line.
    # Also fixed the "relavant" typo in rule 3.
    system_prompt = """You are a specialized product inquiry assistant. \
Your primary and ONLY role is to answer user questions based on \
the 'Retrieved Documents' provided below.

Follow these rules strictly:
1. Base your entire response on the information found within the 'Retrieved Documents'. \
Do not use any external knowledge.
2. If there are no documents or \
the documents do not contain the information needed to answer the query, \
you MUST respond with: "I'm sorry, but I cannot answer your question with the information I have."
3. If the documents contain relevant information, use it to construct a clear and concise answer.
The documents may include metadata such as price, product name, brand, and category.
The documents may also include product descriptions and features.
The documents may include customer reviews which can be used to answer questions \
about product quality and user satisfaction.
4. Some documents may not be fully relevant; \
carefully select and synthesize information only from the relevant parts.
5. Do not fabricate or assume any information not present in the documents.
6. Analyze the chat history provided under 'Chat History' for conversational context, \
but do not use it as a source for answers.
7. Respond in a friendly and helpful tone, with concise answers and directly related to the query.
8. Make sure to ask the user relevant follow-up questions.
9. Always format prices with a dollar sign and two decimal places.
10. Do not use the term 'Retrieved Documents' in your response. It is only for your reference.


Retrieved Documents:
```
{context}
```

Chat History:
{chat_history}
"""

    # Separate documents clearly so the model can tell them apart.
    context = "\n\n---\n\n".join(retrieved_docs)

    # Flatten the Gradio messages-format history into a plain transcript.
    formatted_history = ""
    for msg in history:
        if msg["role"] == "user":
            formatted_history += f"User: {msg['content']}\n"
        elif msg["role"] == "assistant":
            formatted_history += f"Assistant: {msg['content']}\n"

    prompt = system_prompt.format(context=context, chat_history=formatted_history)

    client = get_api_client()
    return client.generate_response(query, prompt)
| |
|
| |
|
| |
|
def rewrite_query(query: str, history: list) -> str:
    """
    Rewrites a conversational query into a self-contained query using the chat history
    via Hugging Face Inference API.

    Args:
        query: The user's potentially vague query.
        history: The chat history from Gradio.

    Returns:
        A self-contained query.
    """
    system_prompt = """You are an expert at query rewriting. Your task is to rewrite a given 'user query' \
into a self-contained, specific query that can be understood without the context of the 'chat history'.

Follow these rules strictly:
1. Analyze the 'chat history' to understand the context of the conversation.
2. Identify any pronouns (e.g., 'it', 'its', 'they', 'that') or vague references in the 'user query'.
3. Replace these pronouns and vague references with the specific entities or topics they refer to from the chat history.
4. If the 'user query' is already self-contained and specific, return it unchanged.
5. CRITICAL: If the 'user query' is about a completely new topic not covered in the chat history, \
you MUST return it unchanged. Do NOT try to connect it to the previous conversation.
6. The rewritten query should be a single, clear question or statement.
7. Output ONLY the rewritten query, with no extra text, labels, or explanations.

Here are some examples of how to behave:

---
Example 1: Rewriting a contextual query
Chat History:
User: Do you have the TechPro Ultrabook in stock?
Assistant: Yes, the TechPro Ultrabook (TP-UB100) is available.
User query: 'Tell me about its warranty.'
Rewritten query: 'What is the warranty for the TechPro Ultrabook (TP-UB100)?'
---
Example 2: Handling a topic change
Chat History:
User: Do you have the TechPro Ultrabook in stock?
Assistant: Yes, the TechPro Ultrabook (TP-UB100) is available.
User query: 'Okay, do you have any monitors?'
Rewritten query: 'Okay, do you have any monitors?'
---
Example 3: Handling a self-contained query
Chat History:
User: What's the price of the BlueWave Gaming Laptop?
Assistant: The BlueWave Gaming Laptop (BW-GL200) is $1299.99.
User query: 'What is the price of the GameSphere X console?'
Rewritten query: 'What is the price of the GameSphere X console?'
---

Chat History:
{chat_history}
"""

    # Render the Gradio messages-format history as a plain transcript,
    # one "Role: content" line per turn.
    transcript_lines = []
    for msg in history:
        if msg["role"] == "user":
            transcript_lines.append(f"User: {msg['content']}\n")
        elif msg["role"] == "assistant":
            transcript_lines.append(f"Assistant: {msg['content']}\n")
    formatted_history = "".join(transcript_lines)

    prompt = system_prompt.format(chat_history=formatted_history)

    return get_api_client().rewrite_query(query, prompt)
| |
|