import json
import logging
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Generic, Optional, TypeVar
from urllib.parse import urlparse

import chromadb
import requests
import yaml
from open_webui.apps.webui.internal.db import Base, get_db
from open_webui.env import (
    OPEN_WEBUI_DIR,
    DATA_DIR,
    ENV,
    FRONTEND_BUILD_DIR,
    WEBUI_AUTH,
    WEBUI_FAVICON_URL,
    WEBUI_NAME,
    log,
)
from pydantic import BaseModel
from sqlalchemy import JSON, Column, DateTime, Integer, func


class EndpointFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        return record.getMessage().find("/health") == -1


# Filter /health requests out of the uvicorn access log
logging.getLogger("uvicorn.access").addFilter(EndpointFilter())

####################################
# Config helpers
####################################


# Run the Alembic migrations before the config table is queried below
def run_migrations():
    print("Running migrations")
    try:
        from alembic import command
        from alembic.config import Config

        alembic_cfg = Config(OPEN_WEBUI_DIR / "alembic.ini")

        # Set the script location dynamically
        migrations_path = OPEN_WEBUI_DIR / "migrations"
        alembic_cfg.set_main_option("script_location", str(migrations_path))

        command.upgrade(alembic_cfg, "head")
    except Exception as e:
        print(f"Error running migrations: {e}")


run_migrations()

class Config(Base):
    __tablename__ = "config"

    id = Column(Integer, primary_key=True)
    data = Column(JSON, nullable=False)
    version = Column(Integer, nullable=False, default=0)
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(DateTime, nullable=True, onupdate=func.now())


def load_json_config():
    with open(f"{DATA_DIR}/config.json", "r") as file:
        return json.load(file)


def save_to_db(data):
    with get_db() as db:
        existing_config = db.query(Config).first()
        if not existing_config:
            new_config = Config(data=data, version=0)
            db.add(new_config)
        else:
            existing_config.data = data
            existing_config.updated_at = datetime.now()
            db.add(existing_config)
        db.commit()


def reset_config():
    with get_db() as db:
        db.query(Config).delete()
        db.commit()


# When initializing, check if config.json exists and migrate it to the database
if os.path.exists(f"{DATA_DIR}/config.json"):
    data = load_json_config()
    save_to_db(data)
    os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json")

DEFAULT_CONFIG = {
    "version": 0,
    "ui": {
        "default_locale": "",
        "prompt_suggestions": [
            {
                "title": [
                    "Help me study",
                    "vocabulary for a college entrance exam",
                ],
                "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
            },
            {
                "title": [
                    "Give me ideas",
                    "for what to do with my kids' art",
                ],
                "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
            },
            {
                "title": ["Tell me a fun fact", "about the Roman Empire"],
                "content": "Tell me a random fun fact about the Roman Empire",
            },
            {
                "title": [
                    "Show me a code snippet",
                    "of a website's sticky header",
                ],
                "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
            },
            {
                "title": [
                    "Explain options trading",
                    "if I'm familiar with buying and selling stocks",
                ],
                "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
            },
            {
                "title": ["Overcome procrastination", "give me tips"],
                "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
            },
            {
                "title": [
                    "Grammar check",
                    "rewrite it for better readability ",
                ],
                "content": 'Check the following sentence for grammar and clarity: "[sentence]". Rewrite it for better readability while maintaining its original meaning.',
            },
        ],
    },
}

def get_config():
    with get_db() as db:
        config_entry = db.query(Config).order_by(Config.id.desc()).first()
        return config_entry.data if config_entry else DEFAULT_CONFIG


CONFIG_DATA = get_config()


def get_config_value(config_path: str):
    path_parts = config_path.split(".")
    cur_config = CONFIG_DATA
    for key in path_parts:
        if key in cur_config:
            cur_config = cur_config[key]
        else:
            return None
    return cur_config
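
# Illustrative note: a dotted path is resolved one key at a time, so
# get_config_value("ui.default_locale") walks CONFIG_DATA["ui"]["default_locale"]
# and returns None as soon as any key along the path is missing.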

PERSISTENT_CONFIG_REGISTRY = []


def save_config(config):
    global CONFIG_DATA
    global PERSISTENT_CONFIG_REGISTRY

    try:
        save_to_db(config)
        CONFIG_DATA = config

        # Trigger updates on all registered PersistentConfig entries
        for config_item in PERSISTENT_CONFIG_REGISTRY:
            config_item.update()
    except Exception as e:
        log.exception(e)
        return False
    return True

T = TypeVar("T")


class PersistentConfig(Generic[T]):
    def __init__(self, env_name: str, config_path: str, env_value: T):
        self.env_name = env_name
        self.config_path = config_path
        self.env_value = env_value
        self.config_value = get_config_value(config_path)
        if self.config_value is not None:
            log.info(f"'{env_name}' loaded from the latest database entry")
            self.value = self.config_value
        else:
            self.value = env_value

        PERSISTENT_CONFIG_REGISTRY.append(self)

    def __str__(self):
        return str(self.value)

    def __dict__(self):
        raise TypeError(
            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
        )

    def __getattribute__(self, item):
        if item == "__dict__":
            raise TypeError(
                "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
            )
        return super().__getattribute__(item)

    def update(self):
        new_value = get_config_value(self.config_path)
        if new_value is not None:
            self.value = new_value
            log.info(f"Updated {self.env_name} to new value {self.value}")

    def save(self):
        log.info(f"Saving '{self.env_name}' to the database")
        path_parts = self.config_path.split(".")
        sub_config = CONFIG_DATA
        for key in path_parts[:-1]:
            if key not in sub_config:
                sub_config[key] = {}
            sub_config = sub_config[key]
        sub_config[path_parts[-1]] = self.value
        save_to_db(CONFIG_DATA)
        self.config_value = self.value

class AppConfig:
    _state: dict[str, PersistentConfig]

    def __init__(self):
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        if isinstance(value, PersistentConfig):
            self._state[key] = value
        else:
            self._state[key].value = value
            self._state[key].save()

    def __getattr__(self, key):
        return self._state[key].value
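
# Illustrative sketch only: `app_config` and SOME_FLAG are hypothetical names
# showing how PersistentConfig and AppConfig are meant to be combined.
#
#   SOME_FLAG = PersistentConfig(
#       "SOME_FLAG", "ui.some_flag", os.environ.get("SOME_FLAG", "") == "true"
#   )
#   app_config = AppConfig()
#   app_config.SOME_FLAG = SOME_FLAG   # registering stores the PersistentConfig entry
#   app_config.SOME_FLAG = True        # later assignments update .value and persist it
#   current = app_config.SOME_FLAG     # attribute reads return the plain value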

####################################
# WEBUI_AUTH (Required for security)
####################################

JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)

####################################
# OAuth config
####################################

ENABLE_OAUTH_SIGNUP = PersistentConfig(
    "ENABLE_OAUTH_SIGNUP",
    "oauth.enable_signup",
    os.environ.get("ENABLE_OAUTH_SIGNUP", "False").lower() == "true",
)

OAUTH_MERGE_ACCOUNTS_BY_EMAIL = PersistentConfig(
    "OAUTH_MERGE_ACCOUNTS_BY_EMAIL",
    "oauth.merge_accounts_by_email",
    os.environ.get("OAUTH_MERGE_ACCOUNTS_BY_EMAIL", "False").lower() == "true",
)

OAUTH_PROVIDERS = {}

GOOGLE_CLIENT_ID = PersistentConfig(
    "GOOGLE_CLIENT_ID",
    "oauth.google.client_id",
    os.environ.get("GOOGLE_CLIENT_ID", ""),
)

GOOGLE_CLIENT_SECRET = PersistentConfig(
    "GOOGLE_CLIENT_SECRET",
    "oauth.google.client_secret",
    os.environ.get("GOOGLE_CLIENT_SECRET", ""),
)

GOOGLE_OAUTH_SCOPE = PersistentConfig(
    "GOOGLE_OAUTH_SCOPE",
    "oauth.google.scope",
    os.environ.get("GOOGLE_OAUTH_SCOPE", "openid email profile"),
)

GOOGLE_REDIRECT_URI = PersistentConfig(
    "GOOGLE_REDIRECT_URI",
    "oauth.google.redirect_uri",
    os.environ.get("GOOGLE_REDIRECT_URI", ""),
)

MICROSOFT_CLIENT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_ID",
    "oauth.microsoft.client_id",
    os.environ.get("MICROSOFT_CLIENT_ID", ""),
)

MICROSOFT_CLIENT_SECRET = PersistentConfig(
    "MICROSOFT_CLIENT_SECRET",
    "oauth.microsoft.client_secret",
    os.environ.get("MICROSOFT_CLIENT_SECRET", ""),
)

MICROSOFT_CLIENT_TENANT_ID = PersistentConfig(
    "MICROSOFT_CLIENT_TENANT_ID",
    "oauth.microsoft.tenant_id",
    os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""),
)

MICROSOFT_OAUTH_SCOPE = PersistentConfig(
    "MICROSOFT_OAUTH_SCOPE",
    "oauth.microsoft.scope",
    os.environ.get("MICROSOFT_OAUTH_SCOPE", "openid email profile"),
)

MICROSOFT_REDIRECT_URI = PersistentConfig(
    "MICROSOFT_REDIRECT_URI",
    "oauth.microsoft.redirect_uri",
    os.environ.get("MICROSOFT_REDIRECT_URI", ""),
)

OAUTH_CLIENT_ID = PersistentConfig(
    "OAUTH_CLIENT_ID",
    "oauth.oidc.client_id",
    os.environ.get("OAUTH_CLIENT_ID", ""),
)

OAUTH_CLIENT_SECRET = PersistentConfig(
    "OAUTH_CLIENT_SECRET",
    "oauth.oidc.client_secret",
    os.environ.get("OAUTH_CLIENT_SECRET", ""),
)

OPENID_PROVIDER_URL = PersistentConfig(
    "OPENID_PROVIDER_URL",
    "oauth.oidc.provider_url",
    os.environ.get("OPENID_PROVIDER_URL", ""),
)

OPENID_REDIRECT_URI = PersistentConfig(
    "OPENID_REDIRECT_URI",
    "oauth.oidc.redirect_uri",
    os.environ.get("OPENID_REDIRECT_URI", ""),
)

OAUTH_SCOPES = PersistentConfig(
    "OAUTH_SCOPES",
    "oauth.oidc.scopes",
    os.environ.get("OAUTH_SCOPES", "openid email profile"),
)

OAUTH_PROVIDER_NAME = PersistentConfig(
    "OAUTH_PROVIDER_NAME",
    "oauth.oidc.provider_name",
    os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)

OAUTH_USERNAME_CLAIM = PersistentConfig(
    "OAUTH_USERNAME_CLAIM",
    "oauth.oidc.username_claim",
    os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)

OAUTH_PICTURE_CLAIM = PersistentConfig(
    "OAUTH_PICTURE_CLAIM",
    "oauth.oidc.avatar_claim",
    os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)

OAUTH_EMAIL_CLAIM = PersistentConfig(
    "OAUTH_EMAIL_CLAIM",
    "oauth.oidc.email_claim",
    os.environ.get("OAUTH_EMAIL_CLAIM", "email"),
)


def load_oauth_providers():
    OAUTH_PROVIDERS.clear()
    if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value:
        OAUTH_PROVIDERS["google"] = {
            "client_id": GOOGLE_CLIENT_ID.value,
            "client_secret": GOOGLE_CLIENT_SECRET.value,
            "server_metadata_url": "https://accounts.google.com/.well-known/openid-configuration",
            "scope": GOOGLE_OAUTH_SCOPE.value,
            "redirect_uri": GOOGLE_REDIRECT_URI.value,
        }

    if (
        MICROSOFT_CLIENT_ID.value
        and MICROSOFT_CLIENT_SECRET.value
        and MICROSOFT_CLIENT_TENANT_ID.value
    ):
        OAUTH_PROVIDERS["microsoft"] = {
            "client_id": MICROSOFT_CLIENT_ID.value,
            "client_secret": MICROSOFT_CLIENT_SECRET.value,
            "server_metadata_url": f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration",
            "scope": MICROSOFT_OAUTH_SCOPE.value,
            "redirect_uri": MICROSOFT_REDIRECT_URI.value,
        }

    if (
        OAUTH_CLIENT_ID.value
        and OAUTH_CLIENT_SECRET.value
        and OPENID_PROVIDER_URL.value
    ):
        OAUTH_PROVIDERS["oidc"] = {
            "client_id": OAUTH_CLIENT_ID.value,
            "client_secret": OAUTH_CLIENT_SECRET.value,
            "server_metadata_url": OPENID_PROVIDER_URL.value,
            "scope": OAUTH_SCOPES.value,
            "name": OAUTH_PROVIDER_NAME.value,
            "redirect_uri": OPENID_REDIRECT_URI.value,
        }


load_oauth_providers()

####################################
# Static DIR
####################################

STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve()

frontend_favicon = FRONTEND_BUILD_DIR / "static" / "favicon.png"

if frontend_favicon.exists():
    try:
        shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    except Exception as e:
        logging.error(f"An error occurred while copying the favicon: {e}")
else:
    logging.warning(f"Frontend favicon not found at {frontend_favicon}")

frontend_splash = FRONTEND_BUILD_DIR / "static" / "splash.png"

if frontend_splash.exists():
    try:
        shutil.copyfile(frontend_splash, STATIC_DIR / "splash.png")
    except Exception as e:
        logging.error(f"An error occurred while copying the splash image: {e}")
else:
    logging.warning(f"Frontend splash not found at {frontend_splash}")

####################################
# CUSTOM_NAME
####################################

CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")

if CUSTOM_NAME:
    try:
        r = requests.get(f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}")
        data = r.json()
        if r.ok:
            if "logo" in data:
                WEBUI_FAVICON_URL = url = (
                    f"https://api.openwebui.com{data['logo']}"
                    if data["logo"][0] == "/"
                    else data["logo"]
                )

                r = requests.get(url, stream=True)
                if r.status_code == 200:
                    with open(f"{STATIC_DIR}/favicon.png", "wb") as f:
                        r.raw.decode_content = True
                        shutil.copyfileobj(r.raw, f)

            if "splash" in data:
                url = (
                    f"https://api.openwebui.com{data['splash']}"
                    if data["splash"][0] == "/"
                    else data["splash"]
                )

                r = requests.get(url, stream=True)
                if r.status_code == 200:
                    with open(f"{STATIC_DIR}/splash.png", "wb") as f:
                        r.raw.decode_content = True
                        shutil.copyfileobj(r.raw, f)

            WEBUI_NAME = data["name"]
    except Exception as e:
        log.exception(e)
        pass

####################################
# File Upload DIR
####################################

UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Cache DIR
####################################

CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Docs DIR
####################################

DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Tools DIR
####################################

TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Functions DIR
####################################

FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)

####################################
# OLLAMA_BASE_URL
####################################

ENABLE_OLLAMA_API = PersistentConfig(
    "ENABLE_OLLAMA_API",
    "ollama.enable",
    os.environ.get("ENABLE_OLLAMA_API", "True").lower() == "true",
)

OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")

AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")

if AIOHTTP_CLIENT_TIMEOUT == "":
    AIOHTTP_CLIENT_TIMEOUT = None
else:
    try:
        AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
    except Exception:
        AIOHTTP_CLIENT_TIMEOUT = 300

K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")

if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    OLLAMA_BASE_URL = (
        OLLAMA_API_BASE_URL[:-4]
        if OLLAMA_API_BASE_URL.endswith("/api")
        else OLLAMA_API_BASE_URL
    )

if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama" and not K8S_FLAG:
        if USE_OLLAMA_DOCKER.lower() == "true":
            # If you use the all-in-one Docker container (Open WebUI + Ollama)
            # built with --build-arg="USE_OLLAMA=true", this only works with
            # http://localhost:11434
            OLLAMA_BASE_URL = "http://localhost:11434"
        else:
            OLLAMA_BASE_URL = "http://host.docker.internal:11434"
    elif K8S_FLAG:
        OLLAMA_BASE_URL = "http://ollama-service.open-webui.svc.cluster.local:11434"

OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
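
# Illustrative only: with a hypothetical value such as
# OLLAMA_BASE_URLS="http://ollama-1:11434;http://ollama-2:11434", the list above
# becomes ["http://ollama-1:11434", "http://ollama-2:11434"]; if the variable is
# unset, it falls back to the single OLLAMA_BASE_URL resolved above.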

####################################
# OPENAI_API
####################################

ENABLE_OPENAI_API = PersistentConfig(
    "ENABLE_OPENAI_API",
    "openai.enable",
    os.environ.get("ENABLE_OPENAI_API", "True").lower() == "true",
)

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")

if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)

OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

OPENAI_API_BASE_URLS = [
    url.strip() if url != "" else "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)

OPENAI_API_KEY = ""
try:
    OPENAI_API_KEY = OPENAI_API_KEYS.value[
        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
    ]
except Exception:
    pass

OPENAI_API_BASE_URL = "https://api.openai.com/v1"
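
# Note (illustrative): OPENAI_API_KEYS and OPENAI_API_BASE_URLS are treated as
# parallel, semicolon-separated lists, so the key at index i is used with the
# base URL at index i. The lookup above just recovers the key that pairs with
# the official https://api.openai.com/v1 endpoint, if one is configured.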

####################################
# WEBUI
####################################

ENABLE_SIGNUP = PersistentConfig(
    "ENABLE_SIGNUP",
    "ui.enable_signup",
    (
        False
        if not WEBUI_AUTH
        else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
    ),
)

ENABLE_LOGIN_FORM = PersistentConfig(
    "ENABLE_LOGIN_FORM",
    "ui.ENABLE_LOGIN_FORM",
    os.environ.get("ENABLE_LOGIN_FORM", "True").lower() == "true",
)

DEFAULT_LOCALE = PersistentConfig(
    "DEFAULT_LOCALE",
    "ui.default_locale",
    os.environ.get("DEFAULT_LOCALE", ""),
)

DEFAULT_MODELS = PersistentConfig(
    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)

DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
    "DEFAULT_PROMPT_SUGGESTIONS",
    "ui.prompt_suggestions",
    [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
        },
        {
            "title": ["Give me ideas", "for what to do with my kids' art"],
            "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
        },
        {
            "title": ["Tell me a fun fact", "about the Roman Empire"],
            "content": "Tell me a random fun fact about the Roman Empire",
        },
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
        {
            "title": [
                "Explain options trading",
                "if I'm familiar with buying and selling stocks",
            ],
            "content": "Explain options trading in simple terms if I'm familiar with buying and selling stocks.",
        },
        {
            "title": ["Overcome procrastination", "give me tips"],
            "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
        },
    ],
)

DEFAULT_USER_ROLE = PersistentConfig(
    "DEFAULT_USER_ROLE",
    "ui.default_user_role",
    os.getenv("DEFAULT_USER_ROLE", "pending"),
)

USER_PERMISSIONS_CHAT_DELETION = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_EDITING = (
    os.environ.get("USER_PERMISSIONS_CHAT_EDITING", "True").lower() == "true"
)

USER_PERMISSIONS_CHAT_TEMPORARY = (
    os.environ.get("USER_PERMISSIONS_CHAT_TEMPORARY", "True").lower() == "true"
)

USER_PERMISSIONS = PersistentConfig(
    "USER_PERMISSIONS",
    "ui.user_permissions",
    {
        "chat": {
            "deletion": USER_PERMISSIONS_CHAT_DELETION,
            "editing": USER_PERMISSIONS_CHAT_EDITING,
            "temporary": USER_PERMISSIONS_CHAT_TEMPORARY,
        }
    },
)

ENABLE_MODEL_FILTER = PersistentConfig(
    "ENABLE_MODEL_FILTER",
    "model_filter.enable",
    os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = PersistentConfig(
    "MODEL_FILTER_LIST",
    "model_filter.list",
    [model.strip() for model in MODEL_FILTER_LIST.split(";")],
)

WEBHOOK_URL = PersistentConfig(
    "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)

ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"

ENABLE_ADMIN_CHAT_ACCESS = (
    os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true"
)

ENABLE_COMMUNITY_SHARING = PersistentConfig(
    "ENABLE_COMMUNITY_SHARING",
    "ui.enable_community_sharing",
    os.environ.get("ENABLE_COMMUNITY_SHARING", "True").lower() == "true",
)

ENABLE_MESSAGE_RATING = PersistentConfig(
    "ENABLE_MESSAGE_RATING",
    "ui.enable_message_rating",
    os.environ.get("ENABLE_MESSAGE_RATING", "True").lower() == "true",
)


def validate_cors_origins(origins):
    for origin in origins:
        if origin != "*":
            validate_cors_origin(origin)


def validate_cors_origin(origin):
    parsed_url = urlparse(origin)

    # Check if the scheme is either http or https
    if parsed_url.scheme not in ["http", "https"]:
        raise ValueError(
            f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed."
        )

    # Ensure that the netloc (domain + port) is present, indicating it's a valid URL
    if not parsed_url.netloc:
        raise ValueError(f"Invalid URL structure in CORS_ALLOW_ORIGIN: '{origin}'.")


# In production you normally only need one origin, since FastAPI serves the
# SvelteKit-built frontend and the backend from the same host and port.
# To test CORS_ALLOW_ORIGIN locally, set something like
# CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080
# in your .env file, depending on your frontend port (5173 in this case).
CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";")

if "*" in CORS_ALLOW_ORIGIN:
    log.warning(
        "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n"
    )

validate_cors_origins(CORS_ALLOW_ORIGIN)


class BannerModel(BaseModel):
    id: str
    type: str
    title: Optional[str] = None
    content: str
    dismissible: bool
    timestamp: int


try:
    banners = json.loads(os.environ.get("WEBUI_BANNERS", "[]"))
    banners = [BannerModel(**banner) for banner in banners]
except Exception as e:
    print(f"Error loading WEBUI_BANNERS: {e}")
    banners = []

WEBUI_BANNERS = PersistentConfig("WEBUI_BANNERS", "ui.banners", banners)
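
# Illustrative value only (the banner text below is hypothetical):
#   WEBUI_BANNERS='[{"id": "1", "type": "info", "title": "Heads up",
#                    "content": "Maintenance tonight", "dismissible": true,
#                    "timestamp": 0}]'
# Each JSON object must provide the BannerModel fields defined above.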

SHOW_ADMIN_DETAILS = PersistentConfig(
    "SHOW_ADMIN_DETAILS",
    "auth.admin.show",
    os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

ADMIN_EMAIL = PersistentConfig(
    "ADMIN_EMAIL",
    "auth.admin.email",
    os.environ.get("ADMIN_EMAIL", None),
)

####################################
# TASKS
####################################

TASK_MODEL = PersistentConfig(
    "TASK_MODEL",
    "task.model.default",
    os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get("TITLE_GENERATION_PROMPT_TEMPLATE", ""),
)

ENABLE_SEARCH_QUERY = PersistentConfig(
    "ENABLE_SEARCH_QUERY",
    "task.search.enable",
    os.environ.get("ENABLE_SEARCH_QUERY", "True").lower() == "true",
)

SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.search.prompt_template",
    os.environ.get("SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE", ""),
)

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    "task.tools.prompt_template",
    os.environ.get("TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE", ""),
)

####################################
# Vector Database
####################################

VECTOR_DB = os.environ.get("VECTOR_DB", "chroma")

# Chroma
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
    CHROMA_HTTP_HEADERS = dict(
        [pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
    )
else:
    CHROMA_HTTP_HEADERS = None
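
# Illustrative only: a hypothetical CHROMA_HTTP_HEADERS="Authorization=Bearer abc123,X-Env=dev"
# is parsed into {"Authorization": "Bearer abc123", "X-Env": "dev"}. Values that
# themselves contain "=" are not supported by this simple split.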
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# The embedding model below is the one defined in the Dockerfile ENV variable.
# If you don't use Docker (or Docker-based deployments such as k8s), the default
# embedding model (sentence-transformers/all-MiniLM-L6-v2) will be used.

# Milvus
MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db")

####################################
# Information Retrieval (RAG)
####################################

# RAG Content Extraction
CONTENT_EXTRACTION_ENGINE = PersistentConfig(
    "CONTENT_EXTRACTION_ENGINE",
    "rag.CONTENT_EXTRACTION_ENGINE",
    os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(),
)

TIKA_SERVER_URL = PersistentConfig(
    "TIKA_SERVER_URL",
    "rag.tika_server_url",
    os.getenv("TIKA_SERVER_URL", "http://tika:9998"),  # Default for sidecar deployment
)

RAG_TOP_K = PersistentConfig(
    "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
    "RAG_RELEVANCE_THRESHOLD",
    "rag.relevance_threshold",
    float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)

ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
    "ENABLE_RAG_HYBRID_SEARCH",
    "rag.enable_hybrid_search",
    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)

RAG_FILE_MAX_COUNT = PersistentConfig(
    "RAG_FILE_MAX_COUNT",
    "rag.file.max_count",
    (
        int(os.environ.get("RAG_FILE_MAX_COUNT"))
        if os.environ.get("RAG_FILE_MAX_COUNT")
        else None
    ),
)

RAG_FILE_MAX_SIZE = PersistentConfig(
    "RAG_FILE_MAX_SIZE",
    "rag.file.max_size",
    (
        int(os.environ.get("RAG_FILE_MAX_SIZE"))
        if os.environ.get("RAG_FILE_MAX_SIZE")
        else None
    ),
)

ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
    "rag.enable_web_loader_ssl_verification",
    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)

RAG_EMBEDDING_ENGINE = PersistentConfig(
    "RAG_EMBEDDING_ENGINE",
    "rag.embedding_engine",
    os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)

PDF_EXTRACT_IMAGES = PersistentConfig(
    "PDF_EXTRACT_IMAGES",
    "rag.pdf_extract_images",
    os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)

RAG_EMBEDDING_MODEL = PersistentConfig(
    "RAG_EMBEDDING_MODEL",
    "rag.embedding_model",
    os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}")

RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
    "rag.embedding_openai_batch_size",
    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
)

RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
    os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}")

RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

CHUNK_SIZE = PersistentConfig(
    "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1000"))
)
CHUNK_OVERLAP = PersistentConfig(
    "CHUNK_OVERLAP",
    "rag.chunk_overlap",
    int(os.environ.get("CHUNK_OVERLAP", "100")),
)

DEFAULT_RAG_TEMPLATE = """You are given a user query, some textual context and rules, all inside xml tags. You have to answer the query based on the context while respecting the rules.
<context>
[context]
</context>
<rules>
- If you don't know, just say so.
- If you are not sure, ask for clarification.
- Answer in the same language as the user query.
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
- Answer directly and without using xml tags.
</rules>
<user_query>
[query]
</user_query>
"""

RAG_TEMPLATE = PersistentConfig(
    "RAG_TEMPLATE",
    "rag.template",
    os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)

RAG_OPENAI_API_BASE_URL = PersistentConfig(
    "RAG_OPENAI_API_BASE_URL",
    "rag.openai_api_base_url",
    os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
    "RAG_OPENAI_API_KEY",
    "rag.openai_api_key",
    os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)

ENABLE_RAG_LOCAL_WEB_FETCH = (
    os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)

YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    "YOUTUBE_LOADER_LANGUAGE",
    "rag.youtube_loader_language",
    os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)

ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)

RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)

# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
    "RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
    "rag.web.search.domain.filter_list",
    [
        # "wikipedia.com",
        # "wikimedia.org",
        # "wikidata.org",
    ],
)

SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.getenv("SEARXNG_QUERY_URL", ""),
)

GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.getenv("GOOGLE_PSE_API_KEY", ""),
)

GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.getenv("GOOGLE_PSE_ENGINE_ID", ""),
)

BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.getenv("BRAVE_SEARCH_API_KEY", ""),
)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.getenv("SERPSTACK_API_KEY", ""),
)

SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.getenv("SERPSTACK_HTTPS", "True").lower() == "true",
)

SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.getenv("SERPER_API_KEY", ""),
)

SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.getenv("SERPLY_API_KEY", ""),
)

TAVILY_API_KEY = PersistentConfig(
    "TAVILY_API_KEY",
    "rag.web.search.tavily_api_key",
    os.getenv("TAVILY_API_KEY", ""),
)

SEARCHAPI_API_KEY = PersistentConfig(
    "SEARCHAPI_API_KEY",
    "rag.web.search.searchapi_api_key",
    os.getenv("SEARCHAPI_API_KEY", ""),
)

SEARCHAPI_ENGINE = PersistentConfig(
    "SEARCHAPI_ENGINE",
    "rag.web.search.searchapi_engine",
    os.getenv("SEARCHAPI_ENGINE", ""),
)

RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)

RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)

####################################
# Transcribe
####################################

WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
    os.environ.get("WHISPER_MODEL_AUTO_UPDATE", "").lower() == "true"
)

####################################
# Images
####################################

IMAGE_GENERATION_ENGINE = PersistentConfig(
    "IMAGE_GENERATION_ENGINE",
    "image_generation.engine",
    os.getenv("IMAGE_GENERATION_ENGINE", "openai"),
)

ENABLE_IMAGE_GENERATION = PersistentConfig(
    "ENABLE_IMAGE_GENERATION",
    "image_generation.enable",
    os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)

AUTOMATIC1111_BASE_URL = PersistentConfig(
    "AUTOMATIC1111_BASE_URL",
    "image_generation.automatic1111.base_url",
    os.getenv("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_API_AUTH = PersistentConfig(
    "AUTOMATIC1111_API_AUTH",
    "image_generation.automatic1111.api_auth",
    os.getenv("AUTOMATIC1111_API_AUTH", ""),
)

AUTOMATIC1111_CFG_SCALE = PersistentConfig(
    "AUTOMATIC1111_CFG_SCALE",
    "image_generation.automatic1111.cfg_scale",
    (
        float(os.environ.get("AUTOMATIC1111_CFG_SCALE"))
        if os.environ.get("AUTOMATIC1111_CFG_SCALE")
        else None
    ),
)

AUTOMATIC1111_SAMPLER = PersistentConfig(
    "AUTOMATIC1111_SAMPLER",
    "image_generation.automatic1111.sampler",
    (
        os.environ.get("AUTOMATIC1111_SAMPLER")
        if os.environ.get("AUTOMATIC1111_SAMPLER")
        else None
    ),
)

AUTOMATIC1111_SCHEDULER = PersistentConfig(
    "AUTOMATIC1111_SCHEDULER",
    "image_generation.automatic1111.scheduler",
    (
        os.environ.get("AUTOMATIC1111_SCHEDULER")
        if os.environ.get("AUTOMATIC1111_SCHEDULER")
        else None
    ),
)

COMFYUI_BASE_URL = PersistentConfig(
    "COMFYUI_BASE_URL",
    "image_generation.comfyui.base_url",
    os.getenv("COMFYUI_BASE_URL", ""),
)

COMFYUI_DEFAULT_WORKFLOW = """
{
  "3": {
    "inputs": {
      "seed": 0,
      "steps": 20,
      "cfg": 8,
      "sampler_name": "euler",
      "scheduler": "normal",
      "denoise": 1,
      "model": [
        "4",
        0
      ],
      "positive": [
        "6",
        0
      ],
      "negative": [
        "7",
        0
      ],
      "latent_image": [
        "5",
        0
      ]
    },
    "class_type": "KSampler",
    "_meta": {
      "title": "KSampler"
    }
  },
  "4": {
    "inputs": {
      "ckpt_name": "model.safetensors"
    },
    "class_type": "CheckpointLoaderSimple",
    "_meta": {
      "title": "Load Checkpoint"
    }
  },
  "5": {
    "inputs": {
      "width": 512,
      "height": 512,
      "batch_size": 1
    },
    "class_type": "EmptyLatentImage",
    "_meta": {
      "title": "Empty Latent Image"
    }
  },
  "6": {
    "inputs": {
      "text": "Prompt",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "7": {
    "inputs": {
      "text": "",
      "clip": [
        "4",
        1
      ]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {
      "title": "CLIP Text Encode (Prompt)"
    }
  },
  "8": {
    "inputs": {
      "samples": [
        "3",
        0
      ],
      "vae": [
        "4",
        2
      ]
    },
    "class_type": "VAEDecode",
    "_meta": {
      "title": "VAE Decode"
    }
  },
  "9": {
    "inputs": {
      "filename_prefix": "ComfyUI",
      "images": [
        "8",
        0
      ]
    },
    "class_type": "SaveImage",
    "_meta": {
      "title": "Save Image"
    }
  }
}
"""

COMFYUI_WORKFLOW = PersistentConfig(
    "COMFYUI_WORKFLOW",
    "image_generation.comfyui.workflow",
    os.getenv("COMFYUI_WORKFLOW", COMFYUI_DEFAULT_WORKFLOW),
)

COMFYUI_WORKFLOW_NODES = PersistentConfig(
    "COMFYUI_WORKFLOW_NODES",
    "image_generation.comfyui.nodes",
    [],
)

IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
    os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
    "IMAGES_OPENAI_API_KEY",
    "image_generation.openai.api_key",
    os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)

IMAGE_SIZE = PersistentConfig(
    "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)

IMAGE_STEPS = PersistentConfig(
    "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
)

IMAGE_GENERATION_MODEL = PersistentConfig(
    "IMAGE_GENERATION_MODEL",
    "image_generation.model",
    os.getenv("IMAGE_GENERATION_MODEL", ""),
)

####################################
# Audio
####################################

AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)

AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.getenv("AUDIO_STT_ENGINE", ""),
)

AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)

AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)

AUDIO_TTS_API_KEY = PersistentConfig(
    "AUDIO_TTS_API_KEY",
    "audio.tts.api_key",
    os.getenv("AUDIO_TTS_API_KEY", ""),
)

AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.getenv("AUDIO_TTS_ENGINE", ""),
)

AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.getenv("AUDIO_TTS_MODEL", "tts-1"),  # OpenAI default model
)

AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.getenv("AUDIO_TTS_VOICE", "alloy"),  # OpenAI default voice
)

AUDIO_TTS_SPLIT_ON = PersistentConfig(
    "AUDIO_TTS_SPLIT_ON",
    "audio.tts.split_on",
    os.getenv("AUDIO_TTS_SPLIT_ON", "punctuation"),
)

AUDIO_TTS_AZURE_SPEECH_REGION = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_REGION",
    "audio.tts.azure.speech_region",
    os.getenv("AUDIO_TTS_AZURE_SPEECH_REGION", "eastus"),
)

AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT = PersistentConfig(
    "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT",
    "audio.tts.azure.speech_output_format",
    os.getenv(
        "AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT", "audio-24khz-160kbitrate-mono-mp3"
    ),
)