exaucengarti committed
Commit 26b321b · verified · 1 Parent(s): ee62ce4

Update common/model.py

Files changed (1)
  1. common/model.py +122 -122
common/model.py CHANGED
@@ -1,123 +1,123 @@
-import os
-from typing import Optional
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-class PoetryModel:
-    """
-    Minimal wrapper for two choices:
-    - HuggingFace Llama 3.1 8B Instruct
-    - OpenAI gpt-5-mini
-
-    Use model_name="llama3.1_8b" or "openai".
-    """
-
-    HF_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
-    OPENAI_MODEL_ID = "gpt-5-mini"
-
-    def __init__(
-        self,
-        model_name: Optional[str] = None,
-        device: str = "cpu",
-        openai_api_key: Optional[str] = None,
-        use_llama_guard: bool = False
-    ):
-        self.device = device
-        self.model_type = "hf"
-        self.openai_client = None
-        self.tokenizer = None
-        self.model = None
-        self.use_llama_guard = use_llama_guard
-        self.guard_model = None
-        self.guard_tokenizer = None
-
-        model_name = model_name or "llama3.1_8b"
-
-        if model_name == "openai":
-            self.model_type = "openai"
-            self.model_name = self.OPENAI_MODEL_ID
-            key = openai_api_key or os.getenv("OPENAI_API_KEY")
-            if not key:
-                raise ValueError("OPENAI_API_KEY missing for OpenAI usage.")
-            try:
-                from openai import OpenAI
-                self.openai_client = OpenAI(api_key=key)
-            except ImportError:
-                raise ImportError("Install OpenAI client: pip install openai")
-        elif model_name == "llama3.1_8b":
-            self.model_name = self.HF_MODEL_ID
-            self._load_hf()
-        else:
-            raise ValueError("model_name must be 'llama3.1_8b' or 'openai'.")
-
-        if self.use_llama_guard:
-            self._load_llama_guard()
-
-    def _load_hf(self):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
-        if self.tokenizer.pad_token is None:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-        dtype = torch.float16 if self.device == "cuda" else torch.float32
-        self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_name,
-            dtype=dtype,
-            trust_remote_code=True,
-            low_cpu_mem_usage=True
-        )
-        self.model.to(self.device)
-
-    def _load_llama_guard(self):
-        try:
-            guard_id = "meta-llama/LlamaGuard-7b"
-            self.guard_tokenizer = AutoTokenizer.from_pretrained(guard_id)
-            self.guard_model = AutoModelForCausalLM.from_pretrained(guard_id)
-            self.guard_model.to(self.device)
-        except Exception as e:
-            print(f"Skipping LlamaGuard: {e}")
-            self.use_llama_guard = False
-
-    def _check_with_llama_guard(self, text: str) -> bool:
-        if not (self.use_llama_guard and self.guard_model):
-            return True
-        try:
-            prompt = f"[INST] Is this safe for children? {text} [/INST]"
-            inputs = self.guard_tokenizer(prompt, return_tensors="pt").to(self.device)
-            out = self.guard_model.generate(**inputs, max_new_tokens=16)
-            resp = self.guard_tokenizer.decode(out[0], skip_special_tokens=True).lower()
-            return "safe" in resp
-        except Exception:
-            return True
-
-    def generate(self, prompt: str, max_tokens: int = 128, temperature: float = 0.7) -> str:
-        if self.model_type == "openai":
-            try:
-                resp = self.openai_client.responses.create(
-                    model=self.model_name,
-                    input=prompt,
-                    text={
-                        "verbosity": "medium"
-                    }
-                )
-                text = (resp.output_text or "").strip()
-            except Exception as e:
-                return f"Error: {e}"
-        else:
-            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
-            gen = self.model.generate(
-                **inputs,
-                max_new_tokens=max_tokens,
-                temperature=temperature,
-                do_sample=True,
-                top_k=50,
-                top_p=0.95,
-                pad_token_id=self.tokenizer.pad_token_id
-            )
-            decoded = self.tokenizer.decode(gen[0], skip_special_tokens=True)
-            if decoded.startswith(prompt):
-                text = decoded[len(prompt):].strip()
-            else:
-                text = decoded.strip()
-
-        if self.use_llama_guard and not self._check_with_llama_guard(text):
-            return "Content filtered for safety."
         return text
 
+import os
+from typing import Optional
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+class PoetryModel:
+    """
+    Minimal wrapper for two choices:
+    - HuggingFace Llama 3.1 8B Instruct
+    - OpenAI gpt-5-mini
+
+    Use model_name="llama3.1_8b" or "openai".
+    """
+
+    HF_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
+    OPENAI_MODEL_ID = "gpt-5-mini"
+
+    def __init__(
+        self,
+        model_name: Optional[str] = None,
+        device: str = "cpu",
+        openai_api_key: Optional[str] = None,
+        use_llama_guard: bool = False
+    ):
+        self.device = device
+        self.model_type = "hf"
+        self.openai_client = None
+        self.tokenizer = None
+        self.model = None
+        self.use_llama_guard = use_llama_guard
+        self.guard_model = None
+        self.guard_tokenizer = None
+
+        model_name = os.getenv("DEFAULT_MODEL")
+
+        if model_name == "openai":
+            self.model_type = "openai"
+            self.model_name = self.OPENAI_MODEL_ID
+            key = openai_api_key or os.getenv("OPENAI_API_KEY")
+            if not key:
+                raise ValueError("OPENAI_API_KEY missing for OpenAI usage.")
+            try:
+                from openai import OpenAI
+                self.openai_client = OpenAI(api_key=key)
+            except ImportError:
+                raise ImportError("Install OpenAI client: pip install openai")
+        elif model_name == "llama3.1_8b":
+            self.model_name = self.HF_MODEL_ID
+            self._load_hf()
+        else:
+            raise ValueError("model_name must be 'llama3.1_8b' or 'openai'.")
+
+        if self.use_llama_guard:
+            self._load_llama_guard()
+
+    def _load_hf(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+        dtype = torch.float16 if self.device == "cuda" else torch.float32
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            dtype=dtype,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+        self.model.to(self.device)
+
+    def _load_llama_guard(self):
+        try:
+            guard_id = "meta-llama/LlamaGuard-7b"
+            self.guard_tokenizer = AutoTokenizer.from_pretrained(guard_id)
+            self.guard_model = AutoModelForCausalLM.from_pretrained(guard_id)
+            self.guard_model.to(self.device)
+        except Exception as e:
+            print(f"Skipping LlamaGuard: {e}")
+            self.use_llama_guard = False
+
+    def _check_with_llama_guard(self, text: str) -> bool:
+        if not (self.use_llama_guard and self.guard_model):
+            return True
+        try:
+            prompt = f"[INST] Is this safe for children? {text} [/INST]"
+            inputs = self.guard_tokenizer(prompt, return_tensors="pt").to(self.device)
+            out = self.guard_model.generate(**inputs, max_new_tokens=16)
+            resp = self.guard_tokenizer.decode(out[0], skip_special_tokens=True).lower()
+            return "safe" in resp
+        except Exception:
+            return True
+
+    def generate(self, prompt: str, max_tokens: int = 128, temperature: float = 0.7) -> str:
+        if self.model_type == "openai":
+            try:
+                resp = self.openai_client.responses.create(
+                    model=self.model_name,
+                    input=prompt,
+                    text={
+                        "verbosity": "medium"
+                    }
+                )
+                text = (resp.output_text or "").strip()
+            except Exception as e:
+                return f"Error: {e}"
+        else:
+            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
+            gen = self.model.generate(
+                **inputs,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                do_sample=True,
+                top_k=50,
+                top_p=0.95,
+                pad_token_id=self.tokenizer.pad_token_id
+            )
+            decoded = self.tokenizer.decode(gen[0], skip_special_tokens=True)
+            if decoded.startswith(prompt):
+                text = decoded[len(prompt):].strip()
+            else:
+                text = decoded.strip()
+
+        if self.use_llama_guard and not self._check_with_llama_guard(text):
+            return "Content filtered for safety."
         return text
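
For reference, a minimal usage sketch of the wrapper after this commit. The import path common.model and the example prompt are assumptions for illustration, not part of the change; what the diff itself establishes is that the backend is now read from the DEFAULT_MODEL environment variable rather than defaulting to "llama3.1_8b", so the variable must be set to "llama3.1_8b" or "openai" before the class is constructed (if it is unset, __init__ falls through to the ValueError branch).

import os

# The env var now selects the backend; "openai" additionally requires OPENAI_API_KEY.
os.environ["DEFAULT_MODEL"] = "openai"
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder, substitute a real key

# Assumed import path for common/model.py in this repo.
from common.model import PoetryModel

model = PoetryModel(device="cpu")
poem = model.generate("Write a four-line poem about autumn rain.", max_tokens=64)
print(poem)

Note that with this change the model_name constructor argument no longer affects backend selection; only DEFAULT_MODEL does.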