import os
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class PoetryModel:
    """
    Minimal wrapper for two choices:
      - Hugging Face Llama 3.1 8B Instruct
      - OpenAI gpt-5-mini

    Use model_name="llama3.1_8b" or "openai".
    """

    HF_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
    OPENAI_MODEL_ID = "gpt-5-mini"

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: str = "cpu",
        openai_api_key: Optional[str] = None,
        use_llama_guard: bool = False,
    ):
        self.device = device
        self.model_type = "hf"
        self.openai_client = None
        self.tokenizer = None
        self.model = None
        self.use_llama_guard = use_llama_guard
        self.guard_model = None
        self.guard_tokenizer = None

        # Only fall back to the DEFAULT_MODEL env var when no explicit
        # model_name is passed; previously the argument was always overwritten.
        model_name = model_name or os.getenv("DEFAULT_MODEL", "llama3.1_8b")

        if model_name == "openai":
            self.model_type = "openai"
            self.model_name = self.OPENAI_MODEL_ID
            key = openai_api_key or os.getenv("OPENAI_API_KEY")
            if not key:
                raise ValueError("OPENAI_API_KEY missing for OpenAI usage.")
            try:
                from openai import OpenAI
                self.openai_client = OpenAI(api_key=key)
            except ImportError:
                raise ImportError("Install OpenAI client: pip install openai")
        elif model_name == "llama3.1_8b":
            self.model_name = self.HF_MODEL_ID
            self._load_hf()
        else:
            raise ValueError("model_name must be 'llama3.1_8b' or 'openai'.")

        if self.use_llama_guard:
            self._load_llama_guard()

    def _load_hf(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # fp16 on GPU, fp32 on CPU. Note: older transformers releases name this
        # from_pretrained argument torch_dtype instead of dtype.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            dtype=dtype,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        self.model.to(self.device)

    def _load_llama_guard(self):
        try:
            guard_id = "meta-llama/LlamaGuard-7b"
            self.guard_tokenizer = AutoTokenizer.from_pretrained(guard_id)
            self.guard_model = AutoModelForCausalLM.from_pretrained(guard_id)
            self.guard_model.to(self.device)
        except Exception as e:
            print(f"Skipping LlamaGuard: {e}")
            self.use_llama_guard = False

    def _check_with_llama_guard(self, text: str) -> bool:
        """Return True if the text is judged safe (or if the guard is unavailable)."""
        if not (self.use_llama_guard and self.guard_model):
            return True
        try:
            # Use LlamaGuard's own moderation chat template (default taxonomy)
            # rather than a hand-rolled [INST] prompt.
            chat = [{"role": "user", "content": text}]
            input_ids = self.guard_tokenizer.apply_chat_template(
                chat, return_tensors="pt"
            ).to(self.device)
            out = self.guard_model.generate(input_ids, max_new_tokens=16)
            # Decode only the newly generated tokens. The model answers "safe"
            # or "unsafe ..."; a plain substring test for "safe" matches both,
            # so check the start of the verdict instead.
            resp = self.guard_tokenizer.decode(
                out[0][input_ids.shape[-1]:], skip_special_tokens=True
            ).strip().lower()
            return resp.startswith("safe")
        except Exception:
            return True

    def generate(self, prompt: str, max_tokens: int = 128, temperature: float = 0.7) -> str:
        # max_tokens and temperature only apply to the local HF path; the
        # OpenAI call uses the Responses API defaults for gpt-5-mini.
        if self.model_type == "openai":
            try:
                resp = self.openai_client.responses.create(
                    model=self.model_name,
                    input=prompt,
                    text={"verbosity": "medium"},
                )
                text = (resp.output_text or "").strip()
            except Exception as e:
                return f"Error: {e}"
        else:
            inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=1024
            ).to(self.device)
            gen = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            decoded = self.tokenizer.decode(gen[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the completion is returned.
            if decoded.startswith(prompt):
                text = decoded[len(prompt):].strip()
            else:
                text = decoded.strip()

        if self.use_llama_guard and not self._check_with_llama_guard(text):
            return "Content filtered for safety."
        return text
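

# Example usage (a minimal sketch, not part of the wrapper itself): the local
# Llama path assumes access to the gated meta-llama weights and a CUDA device
# for fp16; the OpenAI path assumes OPENAI_API_KEY is set in the environment.
# The prompt below is purely illustrative.
if __name__ == "__main__":
    model = PoetryModel(model_name="llama3.1_8b", device="cuda")
    poem = model.generate(
        "Write a four-line poem about the sea.",
        max_tokens=96,
        temperature=0.8,
    )
    print(poem)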