# HonestAI — src/config.py
# Upstream commit b3aba24 (JatsTheAIGen): "Update model IDs to use Cerebras
# deployment and add gated repository error handling"
"""
Configuration Management Module
This module provides secure, robust configuration management with:
- Environment variable handling with secure defaults
- Cache directory management with automatic fallbacks
- Comprehensive logging and error handling
- Security best practices for sensitive data
- Backward compatibility with existing code
Environment Variables:
HF_TOKEN: HuggingFace API token (required for API access)
HF_HOME: Primary cache directory for HuggingFace models
TRANSFORMERS_CACHE: Alternative cache directory path
MAX_WORKERS: Maximum worker threads (default: 4)
CACHE_TTL: Cache time-to-live in seconds (default: 3600)
DB_PATH: Database file path (default: sessions.db)
LOG_LEVEL: Logging level (default: INFO)
LOG_FORMAT: Log format (default: json)
Security Notes:
- Never commit .env files to version control
- Use environment variables for all sensitive data
- Cache directories are automatically secured with proper permissions
"""
import os
import logging
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import Field, validator
# Configure logging
# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
class CacheDirectoryManager:
    """
    Resolves a writable HuggingFace cache directory with a layered
    fallback strategy.

    Responsibilities:
    - Multi-level fallback across env vars, home dirs, and /tmp
    - Write-access verification before a candidate is accepted
    - Automatic directory creation with best-effort secure permissions
    """

    @staticmethod
    def get_cache_directory() -> str:
        """
        Return the first cache directory that can be created and written to.

        Candidates are tried in priority order:
        1. HF_HOME environment variable
        2. TRANSFORMERS_CACHE environment variable
        3. ~/.cache/huggingface
        4. ~/.cache/huggingface_fallback
        5. /tmp/huggingface_cache
        with /tmp/huggingface_emergency as a last-resort fallback.

        Returns:
            str: Path to a writable cache directory.
        """
        home = os.path.expanduser("~")
        candidates = [os.getenv("HF_HOME"), os.getenv("TRANSFORMERS_CACHE")]
        if home:
            candidates.append(os.path.join(home, ".cache", "huggingface"))
            candidates.append(os.path.join(home, ".cache", "huggingface_fallback"))
        candidates.append("/tmp/huggingface_cache")

        for candidate in candidates:
            if not candidate:
                # Unset environment variables yield None/"" — skip them.
                continue
            path = Path(candidate)
            try:
                path.mkdir(parents=True, exist_ok=True)
            except (PermissionError, OSError) as exc:
                logger.debug(f"Could not create/access {candidate}: {exc}")
                continue
            # Best-effort secure permissions (rwxr-xr-x); a writable directory
            # is still acceptable if chmod is not permitted here.
            try:
                os.chmod(path, 0o755)
            except (OSError, PermissionError):
                pass
            # Prove write access by round-tripping a probe file.
            probe = path / ".write_test"
            try:
                probe.write_text("test")
                probe.unlink()
            except (PermissionError, OSError) as exc:
                logger.debug(f"Write test failed for {candidate}: {exc}")
                continue
            logger.info(f"✓ Cache directory verified: {candidate}")
            return str(path)

        # Every candidate failed — fall back to an emergency /tmp location.
        fallback = "/tmp/huggingface_emergency"
        try:
            Path(fallback).mkdir(parents=True, exist_ok=True)
        except Exception as exc:
            logger.error(f"Emergency fallback also failed: {exc}")
            # Return a default that will fail gracefully later
            return "/tmp/huggingface"
        logger.warning(f"Using emergency fallback cache: {fallback}")
        return fallback
class Settings(BaseSettings):
    """
    Application settings with secure defaults and validation.

    Values are resolved by pydantic from explicit constructor arguments,
    environment variables, the ``.env`` file, and finally the field
    defaults declared below.

    Backward Compatibility:
    - All existing attributes are preserved
    - hf_token is accessible as string (via property)
    - hf_cache_dir is accessible as property (works like before)
    - All defaults match original implementation

    NOTE(review): the validators below use pydantic v1-style
    ``@validator(..., pre=True)`` without ``always=True``; in pydantic v1
    such validators are not invoked for defaulted fields, so the
    ``v is None`` branches only fire for explicitly supplied values —
    confirm this matches the intended Docker fallback behavior.
    """
    # ==================== HuggingFace Configuration ====================
    # BACKWARD COMPAT: hf_token as regular field (backward compatible).
    # Defaults to "" (never None) so callers can treat it as a plain string.
    hf_token: str = Field(
        default="",
        description="HuggingFace API token",
        env="HF_TOKEN"
    )

    @validator("hf_token", pre=True)
    def validate_hf_token(cls, v):
        """Validate HF token (backward compatible).

        Coerces None/falsy input to "" and logs at DEBUG when no token is
        configured; never raises, so a missing token cannot block startup.
        """
        if v is None:
            return ""
        token = str(v) if v else ""
        if not token:
            logger.debug("HF_TOKEN not set")
        return token

    @property
    def hf_cache_dir(self) -> str:
        """
        Get cache directory with automatic fallback and validation.
        BACKWARD COMPAT: Works like the original hf_cache_dir field.
        Returns:
            str: Path to writable cache directory
        """
        # Resolved lazily and memoized on the instance so the (potentially
        # slow) directory probing in CacheDirectoryManager runs at most once.
        # NOTE(review): assigning an underscore attribute on a pydantic model
        # may require a declared PrivateAttr in some pydantic versions —
        # confirm against the pinned pydantic version.
        if not hasattr(self, '_cached_cache_dir'):
            try:
                self._cached_cache_dir = CacheDirectoryManager.get_cache_directory()
            except Exception as e:
                logger.error(f"Cache directory setup failed: {e}")
                # Fallback to original default
                fallback = os.getenv("HF_HOME", "/tmp/huggingface")
                Path(fallback).mkdir(parents=True, exist_ok=True)
                self._cached_cache_dir = fallback
        return self._cached_cache_dir

    # ==================== Model Configuration ====================
    # The ":cerebras" suffix selects the Cerebras deployment of the model
    # on the inference provider.
    default_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
    )
    embedding_model: str = Field(
        default="intfloat/e5-large-v2",
        description="Model for embeddings (upgraded: 1024-dim embeddings)"
    )
    classification_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Model for classification tasks (Cerebras deployment)"
    )

    # ==================== Performance Configuration ====================
    max_workers: int = Field(
        default=4,
        description="Maximum worker threads for parallel processing",
        env="MAX_WORKERS"
    )

    @validator("max_workers", pre=True)
    def validate_max_workers(cls, v):
        """Validate and convert max_workers (backward compatible).

        Accepts strings or numbers; invalid input falls back to 4 and the
        result is clamped to [1, 16].
        """
        if v is None:
            return 4
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                logger.warning(f"Invalid MAX_WORKERS value: {v}, using default 4")
                return 4
        try:
            val = int(v)
            return max(1, min(16, val))  # Clamp between 1 and 16
        except (ValueError, TypeError):
            return 4

    cache_ttl: int = Field(
        default=3600,
        description="Cache time-to-live in seconds",
        env="CACHE_TTL"
    )

    @validator("cache_ttl", pre=True)
    def validate_cache_ttl(cls, v):
        """Validate cache TTL (backward compatible).

        Any non-negative integer is accepted (0 disables caching by TTL);
        invalid input falls back to 3600.
        """
        if v is None:
            return 3600
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                return 3600
        try:
            return max(0, int(v))
        except (ValueError, TypeError):
            return 3600

    # ==================== Database Configuration ====================
    db_path: str = Field(
        default="sessions.db",
        description="Path to SQLite database file",
        env="DB_PATH"
    )

    @validator("db_path", pre=True)
    def validate_db_path(cls, v):
        """Validate db_path with Docker fallback (backward compatible)."""
        if v is None:
            # Check if we're in Docker (HF Spaces) - if so, use /tmp
            # NOTE(review): os.path.exists("/tmp") is true on any Unix host,
            # so this branch selects /tmp far beyond Docker — confirm intended.
            if os.path.exists("/.dockerenv") or os.path.exists("/tmp"):
                return "/tmp/sessions.db"
            return "sessions.db"
        return str(v)

    faiss_index_path: str = Field(
        default="embeddings.faiss",
        description="Path to FAISS index file",
        env="FAISS_INDEX_PATH"
    )

    @validator("faiss_index_path", pre=True)
    def validate_faiss_path(cls, v):
        """Validate faiss path with Docker fallback (backward compatible)."""
        if v is None:
            # Check if we're in Docker (HF Spaces) - if so, use /tmp
            if os.path.exists("/.dockerenv") or os.path.exists("/tmp"):
                return "/tmp/embeddings.faiss"
            return "embeddings.faiss"
        return str(v)

    # ==================== Session Configuration ====================
    session_timeout: int = Field(
        default=3600,
        description="Session timeout in seconds",
        env="SESSION_TIMEOUT"
    )

    @validator("session_timeout", pre=True)
    def validate_session_timeout(cls, v):
        """Validate session timeout (backward compatible); minimum 60 seconds."""
        if v is None:
            return 3600
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                return 3600
        try:
            return max(60, int(v))
        except (ValueError, TypeError):
            return 3600

    max_session_size_mb: int = Field(
        default=10,
        description="Maximum session size in megabytes",
        env="MAX_SESSION_SIZE_MB"
    )

    @validator("max_session_size_mb", pre=True)
    def validate_max_session_size(cls, v):
        """Validate max session size (backward compatible); clamped to [1, 100] MB."""
        if v is None:
            return 10
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                return 10
        try:
            return max(1, min(100, int(v)))
        except (ValueError, TypeError):
            return 10

    # ==================== Mobile Optimization ====================
    mobile_max_tokens: int = Field(
        default=800,
        description="Maximum tokens for mobile responses",
        env="MOBILE_MAX_TOKENS"
    )

    @validator("mobile_max_tokens", pre=True)
    def validate_mobile_max_tokens(cls, v):
        """Validate mobile max tokens (backward compatible); clamped to [100, 2000]."""
        if v is None:
            return 800
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                return 800
        try:
            return max(100, min(2000, int(v)))
        except (ValueError, TypeError):
            return 800

    mobile_timeout: int = Field(
        default=15000,
        description="Mobile request timeout in milliseconds",
        env="MOBILE_TIMEOUT"
    )

    @validator("mobile_timeout", pre=True)
    def validate_mobile_timeout(cls, v):
        """Validate mobile timeout (backward compatible); clamped to [5000, 60000] ms."""
        if v is None:
            return 15000
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                return 15000
        try:
            return max(5000, min(60000, int(v)))
        except (ValueError, TypeError):
            return 15000

    # ==================== API Configuration ====================
    gradio_port: int = Field(
        default=7860,
        description="Gradio server port",
        env="GRADIO_PORT"
    )

    @validator("gradio_port", pre=True)
    def validate_gradio_port(cls, v):
        """Validate gradio port (backward compatible); clamped to the
        non-privileged range [1024, 65535]."""
        if v is None:
            return 7860
        if isinstance(v, str):
            try:
                v = int(v)
            except ValueError:
                return 7860
        try:
            return max(1024, min(65535, int(v)))
        except (ValueError, TypeError):
            return 7860

    # Binds on all interfaces by default, as required for containerized hosting.
    gradio_host: str = Field(
        default="0.0.0.0",
        description="Gradio server host",
        env="GRADIO_HOST"
    )

    # ==================== Logging Configuration ====================
    log_level: str = Field(
        default="INFO",
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        env="LOG_LEVEL"
    )

    @validator("log_level")
    def validate_log_level(cls, v):
        """Validate log level (backward compatible); normalized to upper case."""
        if not v:
            return "INFO"
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            logger.warning(f"Invalid log level: {v}, using INFO")
            return "INFO"
        return v.upper()

    log_format: str = Field(
        default="json",
        description="Log format (json or text)",
        env="LOG_FORMAT"
    )

    @validator("log_format")
    def validate_log_format(cls, v):
        """Validate log format (backward compatible); normalized to lower case."""
        if not v:
            return "json"
        if v.lower() not in ["json", "text"]:
            logger.warning(f"Invalid log format: {v}, using json")
            return "json"
        return v.lower()

    # ==================== Pydantic Configuration ====================
    class Config:
        """Pydantic configuration"""
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        validate_assignment = True
        # Allow extra fields for backward compatibility
        extra = "ignore"

    # ==================== Utility Methods ====================
    def validate_configuration(self) -> bool:
        """
        Validate configuration and log status.
        Returns:
            bool: True if configuration is valid, False otherwise
        """
        try:
            # Touching the property forces cache-directory resolution, so any
            # filesystem problem surfaces here rather than at first model load.
            cache_dir = self.hf_cache_dir
            if logger.isEnabledFor(logging.INFO):
                logger.info("Configuration validated:")
                logger.info(f" - Cache directory: {cache_dir}")
                logger.info(f" - Max workers: {self.max_workers}")
                logger.info(f" - Log level: {self.log_level}")
                logger.info(f" - HF token: {'Set' if self.hf_token else 'Not set'}")
            return True
        except Exception as e:
            logger.error(f"Configuration validation failed: {e}")
            return False
# ==================== Global Settings Instance ====================
def get_settings() -> Settings:
    """
    Return the process-wide Settings singleton, creating it on first call.

    Returns:
        Settings: Global settings instance

    Note:
        The instance is cached on the function object itself, so repeated
        calls are cheap and always yield the same object. Validation runs
        once at creation and is best-effort: a validation failure is logged
        but never raised to the caller.
    """
    try:
        return get_settings._instance
    except AttributeError:
        pass
    instance = Settings()
    get_settings._instance = instance
    try:
        instance.validate_configuration()
    except Exception as e:
        logger.warning(f"Configuration validation warning: {e}")
    return instance
# Create the global settings instance eagerly (backward compatible).
settings = get_settings()

# Emit a startup summary at import time. This is strictly best-effort:
# any logging failure is downgraded to DEBUG so it can never break import.
if logger.isEnabledFor(logging.INFO):
    try:
        _banner = "=" * 60
        logger.info(_banner)
        logger.info("Configuration Loaded")
        logger.info(_banner)
        logger.info(f"Cache directory: {settings.hf_cache_dir}")
        logger.info(f"Max workers: {settings.max_workers}")
        logger.info(f"Log level: {settings.log_level}")
        logger.info(_banner)
    except Exception as e:
        logger.debug(f"Configuration logging skipped: {e}")