Jonathan Bejarano
decreasing hallucination
276a68a
raw
history blame
18.5 kB
import gradio as gr
from huggingface_hub import InferenceClient
import re
import random
import os
import requests
from bs4 import BeautifulSoup
from dotenv import load_dotenv
# Pull settings from a .env file when one is present (no-op otherwise)
load_dotenv()

# A custom endpoint plus a token together switch the app into local mode
BASE_URL = os.environ.get('BASE_URL')
LOCAL_TOKEN = os.environ.get('TOKEN')
LOCAL_MODE = True if (BASE_URL and LOCAL_TOKEN) else False

# Model identifier; falls back to a small instruct model when not configured
MODEL_NAME = os.environ.get('MODEL_NAME', 'meta-llama/Llama-3.2-3B-Instruct')
# Countries available in the game. Each entry pairs the display name (inserted
# into the system prompt and shown to the player at game end) with the path of
# its facts page, relative to the site root used by fetch_country_facts().
COUNTRIES = [
    {"name": "Algeria", "url": "/algeria-facts-for-kids.html"},
    {"name": "Angola", "url": "/facts-about-angola.html"},
    {"name": "Argentina", "url": "/argentina-facts.html"},
    {"name": "Australia", "url": "/australia-facts.html"},
    {"name": "Austria", "url": "/austria-facts.html"},
    {"name": "Bahamas", "url": "/facts-about-the-bahamas.html"},
    {"name": "Barbados", "url": "/barbados-facts.html"},
    {"name": "Belgium", "url": "/belgium-facts.html"},
    {"name": "Belize", "url": "/facts-about-belize.html"},
    {"name": "Bhutan", "url": "/bhutan-facts.html"},
    {"name": "Bolivia", "url": "/bolivia-facts.html"},
    {"name": "Botswana", "url": "/facts-about-botswana.html"},
    {"name": "Brazil", "url": "/brazil-facts.html"},
    {"name": "Bulgaria", "url": "/facts-about-bulgaria.html"},
    {"name": "Canada", "url": "/canada-facts-for-kids.html"},
    {"name": "Chile", "url": "/chile-facts.html"},
    {"name": "China", "url": "/china-facts.html"},
    {"name": "Colombia", "url": "/colombia-facts.html"},
    {"name": "Comoros", "url": "/comoros-facts.html"},
    {"name": "Costa Rica", "url": "/costa-rica-facts.html"},
    {"name": "Croatia", "url": "/croatia-facts.html"},
    {"name": "Cuba", "url": "/cuba-facts.html"},
    {"name": "Cyprus", "url": "/cyprus-facts-for-kids.html"},
    {"name": "Denmark", "url": "/denmark-facts.html"},
    {"name": "Dominican Republic", "url": "/dominican-republic-facts.html"},
    {"name": "Ecuador", "url": "/ecuador-facts.html"},
    {"name": "Egypt", "url": "/egypt-facts.html"},
    {"name": "Estonia", "url": "/estonia-facts-for-kids.html"},
    {"name": "Eswatini", "url": "/eswatini-facts.html"},
    {"name": "Ethiopia", "url": "/ethiopia-facts.html"},
    {"name": "Fiji", "url": "/facts-about-fiji.html"},
    {"name": "Finland", "url": "/finland-facts.html"},
    {"name": "France", "url": "/france-facts.html"},
    {"name": "Georgia", "url": "/georgia-facts.html"},
    {"name": "Germany", "url": "/germany-facts.html"},
    {"name": "Ghana", "url": "/ghana-facts.html"},
    {"name": "Greece", "url": "/greece-facts.html"},
    {"name": "Greenland", "url": "/facts-about-greenland.html"},
    {"name": "Guatemala", "url": "/guatemala-facts.html"},
    {"name": "Guyana", "url": "/guyana-facts.html"},
    {"name": "Honduras", "url": "/honduras-facts-for-kids.html"},
    {"name": "Hong Kong", "url": "/facts-about-hong-kong.html"},
    {"name": "Hungary", "url": "/hungary-facts.html"},
    {"name": "Iceland", "url": "/iceland-facts.html"},
    {"name": "India", "url": "/india-for-kids.html"},
    {"name": "Indonesia", "url": "/indonesia-facts.html"},
    {"name": "Iran", "url": "/iran-facts-for-kids.html"},
    {"name": "Ireland", "url": "/ireland-for-kids.html"},
    {"name": "Israel", "url": "/israel-facts.html"},
    {"name": "Italy", "url": "/italy-facts.html"},
    {"name": "Jamaica", "url": "/jamaica-facts.html"},
    {"name": "Japan", "url": "/japan-facts.html"},
    {"name": "Kenya", "url": "/facts-about-kenya.html"},
    {"name": "Kiribati", "url": "/facts-about-kiribati.html"},
    {"name": "Latvia", "url": "/latvia-facts-for-kids.html"},
    {"name": "Lesotho", "url": "/lesotho-facts.html"},
    {"name": "Liberia", "url": "/facts-about-liberia.html"},
    {"name": "Lithuania", "url": "/lithuania-facts-for-kids.html"},
    {"name": "Luxembourg", "url": "/luxembourg-facts.html"},
    {"name": "Macao", "url": "/facts-about-macao.html"},
    {"name": "Madagascar", "url": "/facts-about-madagascar.html"},
    {"name": "Malaysia", "url": "/malaysia-facts.html"},
    {"name": "Maldives", "url": "/maldives-facts.html"},
    {"name": "Malta", "url": "/malta-for-kids.html"},
    {"name": "Mauritius", "url": "/mauritius-facts.html"},
    {"name": "Mexico", "url": "/mexico-facts.html"},
    {"name": "Micronesia", "url": "/facts-about-micronesia.html"},
    {"name": "Moldova", "url": "/moldova-facts-for-kids.html"},
    {"name": "Monaco", "url": "/facts-about-monaco.html"},
    {"name": "Morocco", "url": "/morocco-facts.html"},
    {"name": "Mozambique", "url": "/mozambique-facts.html"},
    {"name": "Myanmar", "url": "/myanmar-facts.html"},
    {"name": "Namibia", "url": "/namibia-facts.html"},
    {"name": "Nauru", "url": "/facts-about-nauru.html"},
    {"name": "Nepal", "url": "/nepal-facts.html"},
    {"name": "Netherlands", "url": "/facts-about-the-netherlands.html"},
    {"name": "New Zealand", "url": "/new-zealand-facts.html"},
    {"name": "Nicaragua", "url": "/nicaragua-facts.html"},
    {"name": "Nigeria", "url": "/nigeria-facts.html"},
    {"name": "Norway", "url": "/norway-facts.html"},
    {"name": "Pakistan", "url": "/pakistan-facts.html"},
    {"name": "Panama", "url": "/panama-facts.html"},
    {"name": "Papua New Guinea", "url": "/papua-new-guinea.html"},
    {"name": "Peru", "url": "/peru-facts.html"},
    {"name": "Philippines", "url": "/philippines-facts.html"},
    {"name": "Poland", "url": "/poland-facts.html"},
    {"name": "Portugal", "url": "/portugal-facts.html"},
    {"name": "Puerto Rico", "url": "/facts-about-puerto-rico.html"},
    {"name": "Qatar", "url": "/qatar-facts.html"},
    {"name": "Romania", "url": "/romania-facts-for-kids.html"},
    {"name": "Russia", "url": "/russia-facts.html"},
    {"name": "Samoa", "url": "/facts-about-samoa.html"},
    {"name": "San Marino", "url": "/facts-about-san-marino.html"},
    {"name": "Serbia", "url": "/facts-about-serbia.html"},
    {"name": "Seychelles", "url": "/seychelles-facts.html"},
    {"name": "Singapore", "url": "/singapore-facts.html"},
    {"name": "Solomon Islands", "url": "/solomon-islands-facts.html"},
    {"name": "South Africa", "url": "/south-africa-for-kids.html"},
    {"name": "South Korea", "url": "/south-korea-facts.html"},
    {"name": "Spain", "url": "/spain-facts.html"},
    {"name": "Sri Lanka", "url": "/sri-lanka-facts.html"},
    {"name": "Suriname", "url": "/suriname-facts.html"},
    {"name": "Sweden", "url": "/Sweden-facts.html"},
    {"name": "Switzerland", "url": "/switzerland-facts.html"},
    {"name": "Taiwan", "url": "/taiwan-facts.html"},
    {"name": "Tanzania", "url": "/tanzania-facts.html"},
    {"name": "Thailand", "url": "/thailand-facts.html"},
    {"name": "Togo", "url": "/togo-facts-for-kids.html"},
    {"name": "Tonga", "url": "/facts-about-tonga.html"},
    {"name": "Tunisia", "url": "/tunisia-facts.html"},
    # Name repaired from a mis-decoded UTF-8 sequence; the page path still uses "turkey"
    {"name": "Türkiye", "url": "/turkey-facts.html"},
    {"name": "Tuvalu", "url": "/facts-about-tuvalu.html"},
    {"name": "Uganda", "url": "/facts-about-uganda.html"},
    {"name": "Ukraine", "url": "/ukraine-for-kids.html"},
    {"name": "United Arab Emirates", "url": "/uae-facts.html"},
    {"name": "United Kingdom", "url": "/uk-facts.html"},
    {"name": "United States of America", "url": "/usa-facts.html"},
    {"name": "Uruguay", "url": "/uruguay-facts.html"},
    {"name": "Vanuatu", "url": "/facts-about-vanuatu.html"},
    {"name": "Venezuela", "url": "/venezuela-for-kids.html"},
    {"name": "Vietnam", "url": "/vietnam-facts.html"},
    {"name": "Zambia", "url": "/zambia-facts.html"},
]
def fetch_country_facts(country_url):
    """Download a country's page from Kids World Travel Guide and extract fact text.

    Args:
        country_url: Site-relative path of the page (appended to the site root).

    Returns:
        Up to ten text snippets joined by newlines and capped at 2000 characters
        (with "..." appended when cut), or a fixed fallback sentence when
        anything goes wrong while fetching or parsing.
    """
    site_root = "https://www.kids-world-travel-guide.com"
    page_url = site_root + country_url

    try:
        reply = requests.get(page_url, timeout=10)
        reply.raise_for_status()
        page = BeautifulSoup(reply.content, 'html.parser')

        # Paragraphs: inspect only the first 10, keeping longer ones that are
        # not "Related ..." navigation blurbs.
        paragraph_texts = (node.get_text().strip() for node in page.find_all('p')[:10])
        snippets = [
            text
            for text in paragraph_texts
            if len(text) > 50 and not text.startswith('Related')
        ]

        # List items: inspect only the first 15, keeping fact-sized entries.
        item_texts = (node.get_text().strip() for node in page.find_all('li')[:15])
        snippets.extend(text for text in item_texts if 20 < len(text) < 200)

        # Cap both the number of snippets and the overall size so the prompt
        # stays within token limits.
        summary = '\n'.join(snippets[:10])
        if len(summary) > 2000:
            summary = summary[:2000] + "..."
        return summary
    except Exception as err:
        # Best effort: the game can still run without the extra facts.
        print(f"Error fetching facts for {country_url}: {err!s}")
        return "Unable to fetch additional facts about this country."
def get_system_message_with_country():
    """Pick a random country and build the game host's system prompt for it.

    Side effects: rebinds the module-level ``selected_country`` and
    ``selected_country_dict``, prints two progress lines, and calls
    ``fetch_country_facts`` for the chosen country's page.

    Returns:
        The system prompt string embedding the country name, the fetched facts
        and the rules of the game.
    """
    global selected_country, selected_country_dict

    # Draw the secret country and publish it through the module globals
    selected_country_dict = random.choice(COUNTRIES)
    selected_country = selected_country_dict["name"]
    facts_path = selected_country_dict["url"]

    print(f"Selected country for this session: {selected_country}")
    print(f"Fetching facts from: {facts_path}")
    facts = fetch_country_facts(facts_path)

    # Short alias keeps the prompt template below readable
    country = selected_country
    return f"""You are a friendly geography game host playing 20 questions with students. You are thinking of the country: {country}
COUNTRY FACTS (use these to answer questions accurately - DO NOT reveal the country name):
{facts}
RULES:
1. NEVER reveal the country name ({country}) in your responses
2. Answer only 'Yes' or 'No' to their questions
3. Keep track of how many questions they've asked
4. When they correctly guess or ask if it is {country}, respond with: 'Congratulations! The country was <<{country}>>'
5. If they reach 20 questions without guessing correctly, respond with: 'Game over! The country was <<{country}>>'
6. Be encouraging and give helpful hints through your yes/no answers
7. If they want to play again tell them they need to reload the page.
8. IMPORTANT: Only accept the country name "{country}" as correct, but Spelling is not important and they can ask a question like it is? Do NOT accept neighboring countries, similar countries, or regions that contain this country.
9. If they guess a neighboring country or similar country, respond with "No" and continue the game.
10. Be very strict about the exact country match - only "{country}" is the correct answer.
11. Use the COUNTRY FACTS above to provide accurate yes/no answers - do not make up information."""
# Build the initial system prompt (which also picks the first random country) when the module is loaded.
current_system = get_system_message_with_country()
def format_game_result(response):
    """Turn the model's end-of-game reply into a styled message for the player.

    Args:
        response: The full text returned by the model for this turn.

    Returns:
        A Markdown win message when the reply contains "Congratulations", a
        Markdown loss message when it contains "Game over", and otherwise the
        reply unchanged. Both styled messages name the module-level
        ``selected_country``.
    """
    # Debug trace only; it does not influence which message is returned.
    if "The country was" in response:
        print(f"🔍 DEBUG - Game end detected! Country extracted: {selected_country}")
    else:
        print("🔍 DEBUG - Regular response (no game end)")

    if "Congratulations" in response:
        return (
            f"🎉 **Congratulations!** You correctly guessed **{selected_country}**! Well done! 🎉"
            "\n\nTo play another round, please start a new conversation or reload the page."
        )
    if "Game over" in response:
        return (
            "😔 **Game Over!** You've used all 20 questions. "
            f"The country I was thinking of was **{selected_country}**. 😔"
            "\n\nTo try again, please start a new conversation or reload the page."
        )
    return response
def _log_request_debug(messages, max_tokens):
    """Print size diagnostics for the outgoing chat request (debug aid only)."""
    total_input_chars = sum(len(str(msg.get("content", ""))) for msg in messages)
    estimated_input_tokens = total_input_chars // 4  # Rough approximation: 4 chars per token
    print(f"🔍 DEBUG - Estimated input tokens: {estimated_input_tokens}")
    print(f"🔍 DEBUG - Messages count: {len(messages)}")
    print(f"🔍 DEBUG - Max tokens setting: {max_tokens}")

    # Show each message's role and length, with a short preview per role
    for position, msg in enumerate(messages, start=1):
        role = msg.get("role", "unknown")
        content = str(msg.get("content", ""))
        print(f"🔍 DEBUG - Message {position} ({role}): {len(content)} chars")
        if role == "system":
            print(f"🔍 DEBUG - System message preview: {content[:100]}...")
        elif role == "user":
            print(f"🔍 DEBUG - User message: {content}")
        elif role == "assistant":
            print(f"🔍 DEBUG - Assistant message: {content[:50]}...")


def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    hf_token: gr.OAuthToken | None = None,
):
    """Answer one chat turn of the geography game.

    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference

    Args:
        message: The player's latest message.
        history: Earlier turns as role/content dicts; an empty history starts
            a new game with a freshly selected country.
        system_message: Accepted for interface compatibility; not used here
            because the prompt comes from ``current_system``.
        max_tokens: Generation limit forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling value forwarded to the model.
        hf_token: OAuth token required in cloud mode; ignored in local mode.

    Returns:
        The model's reply (restyled by ``format_game_result`` when the game
        has ended), or a plain-text error / login prompt. Errors are returned
        as text rather than raised.
    """
    # NOTE(review): current_system / selected_country are module-level, so they
    # appear to be shared by every chat session this process serves - confirm
    # whether concurrent players are expected.
    global current_system

    # If this is the start of a new conversation (empty history), generate a new country
    if not history:
        current_system = get_system_message_with_country()
        print(f"🔍 DEBUG - New session started, selected country: {selected_country}")

    messages = [
        {"role": "system", "content": current_system},
        *history,
        {"role": "user", "content": message},
    ]
    _log_request_debug(messages, max_tokens)

    # Choose client based on whether we're running locally or in the cloud
    if LOCAL_MODE:
        try:
            # Use local inference server
            client = InferenceClient(model=BASE_URL, token=LOCAL_TOKEN)
        except Exception as e:
            return f"Error connecting to local model: {str(e)}"
    else:
        # Cloud mode needs the player's HuggingFace OAuth token
        if not hf_token or not hf_token.token:
            return "Please log in with your HuggingFace account to play the geography game!"
        client = InferenceClient(token=hf_token.token, model=MODEL_NAME)

    # Collect streamed pieces in a list and join once, instead of growing a
    # string chunk by chunk.
    pieces = []
    try:
        for message_chunk in client.chat_completion(
            messages,
            model=MODEL_NAME,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            choices = message_chunk.choices
            # Truthiness check also tolerates a chunk whose choices is None or empty
            if choices and choices[0].delta.content:
                pieces.append(choices[0].delta.content)
    except Exception as e:
        return f"Error during inference: {str(e)}"

    response = "".join(pieces)

    # Debug: Show output token statistics
    estimated_output_tokens = len(response) // 4  # Rough approximation
    print(f"🔍 DEBUG - Output token chunks received: {len(pieces)}")
    print(f"🔍 DEBUG - Estimated output tokens (by chars): {estimated_output_tokens}")
    print(f"🔍 DEBUG - Response length: {len(response)} characters")
    print(f"🔍 DEBUG - Raw response: {response}")

    # Game-end replies are restyled; format_game_result logs the detection itself
    if "The country was" in response:
        return format_game_result(response)

    print("🔍 DEBUG - Regular response (no game end)")
    return response
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Description shown above the chat; local mode prefixes a note naming the model in use
description = "I am thinking of a country, you have 20 yes or no questions to ask me to help you figure out what the country is"
if LOCAL_MODE:
    description = f"🏠 Running locally with {MODEL_NAME}. {description}"
# Starter questions offered in the chat UI; each example is a one-element row
examples = [
    [question]
    for question in (
        "Is the country located in Europe?",
        "Is it in the Northern Hemisphere?",
        "Is the official language Spanish?",
        "Is the capital city Rome?",
        "Is this country bordered by an ocean?",
        "Does this country have more than 100 million people?",
        "Is this country known for producing coffee?",
        "Was this country ever a colony of the United Kingdom?",
        "Is this country located on an island?",
        "Is the currency the Euro?",
    )
]
# Build the chat interface; how respond() gets its extra arguments depends on the mode
if LOCAL_MODE:
    # Local mode - no OAuth needed, so the generation settings are fixed here
    def custom_respond(message, history):
        """Forward a chat turn to respond() with fixed settings and no OAuth token."""
        return respond(
            message,
            history,
            "",     # system_message
            2048,   # max_tokens
            0.3,    # temperature
            0.6,    # top_p
            None,   # hf_token
        )

    chatbot = gr.ChatInterface(
        custom_respond,
        examples=examples,
        cache_examples=False,
        description=description,
        type="messages",
    )
else:
    # Cloud mode - use OAuth; the same settings travel as hidden additional inputs
    chatbot = gr.ChatInterface(
        respond,
        examples=examples,
        cache_examples=False,
        description=description,
        type="messages",
        additional_inputs=[
            gr.Textbox(value="", visible=False),  # system_message (hidden)
            gr.Slider(minimum=1, maximum=4096, value=2048, visible=False),  # max_tokens (hidden)
            gr.Slider(minimum=0.1, maximum=2.0, value=0.3, visible=False),  # temperature (hidden)
            gr.Slider(minimum=0.1, maximum=1.0, value=0.6, visible=False),  # top_p (hidden)
        ],
    )
# Page layout: optional login sidebar first, then the chat interface
demo = gr.Blocks()
with demo:
    if not LOCAL_MODE:
        # Only show login button when running in cloud mode
        with gr.Sidebar():
            gr.LoginButton()
    chatbot.render()

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch()