thelip committed on
Commit ce47731 · verified · 1 Parent(s): f00d573

Create app.py

Files changed (1)
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch
+ import os
+
+ # --- Configuration ---
+ MODEL_NAME = os.getenv("MODEL_NAME", "google/gemma-2b-it")  # Or "google/gemma-7b-it" if you have resources
+ DEVICE = "cpu"  # Explicitly set to CPU
+ TORCH_DTYPE = torch.float32  # Use float32 for CPU for broader compatibility and stability
+ # For some newer CPUs, bfloat16 might offer speedups if supported,
+ # but can sometimes be less stable or require specific setups.
+
+ # --- Model Loading ---
+ # This will run when the Docker container starts, or when the app is first imported.
+ # It might take a few minutes for larger models.
+ print(f"Loading model: {MODEL_NAME} on {DEVICE} with dtype {TORCH_DTYPE}...")
+ try:
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         torch_dtype=TORCH_DTYPE,
+         # low_cpu_mem_usage=True,  # Can be useful for very large models on CPU, but might slow down loading
+         # device_map="auto",       # 'auto' will select CPU if no GPU is available or if specified.
+         #                          # Forcing CPU ensures no GPU attempts.
+     )
+     model.to(DEVICE)  # Ensure model is on CPU
+     print(f"Model {MODEL_NAME} loaded successfully on {DEVICE}.")
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     # If model loading fails, we can't serve requests.
+     # Depending on deployment, you might want to exit or handle this differently.
+     raise RuntimeError(f"Failed to load model: {e}") from e
+
+
+ # --- FastAPI App ---
+ app = FastAPI(
+     title="Gemma CPU Inference API",
+     description="API to run inference on a Gemma model using CPU.",
+     version="0.1.0"
+ )
+
+ class GenerationRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 50
+     temperature: float = 0.7
+     do_sample: bool = True
+
+ class GenerationResponse(BaseModel):
+     generated_text: str
+     input_prompt: str
+
+ @app.post("/generate", response_model=GenerationResponse)
+ async def generate_text(request: GenerationRequest):
+     """
+     Generates text based on the input prompt using the loaded Gemma model.
+     """
+     if not model or not tokenizer:
+         raise HTTPException(status_code=503, detail="Model not loaded or failed to load.")
+
+     print(f"Received request: {request.prompt[:50]}...")  # Log snippet of prompt
+
+     try:
+         # Format prompt for instruction-tuned models (like gemma-*-it).
+         # This is a common format; adjust if your model expects something different.
+         chat = [
+             {"role": "user", "content": request.prompt},
+         ]
+         formatted_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+
+         input_ids = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
+
+         print(f"Generating text with max_new_tokens={request.max_new_tokens}, temperature={request.temperature}...")
+         with torch.no_grad():  # Important for inference
+             outputs = model.generate(
+                 **input_ids,
+                 max_new_tokens=request.max_new_tokens,
+                 temperature=request.temperature,
+                 do_sample=request.do_sample,
+                 # Add other generation parameters as needed: top_k, top_p, etc.
+             )
+
+         # Decode the generated text (only the new tokens).
+         # The generated output includes the input prompt, so we slice it off.
+         # For some models, the slice point might need adjustment:
+         # decoded_text = tokenizer.decode(outputs[0, input_ids.input_ids.shape[1]:], skip_special_tokens=True)
+
+         # A more robust way to get only the generated part, especially with chat templates
+         full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         # Remove the prompt part. This depends on how apply_chat_template works.
+         # For many models, the prompt itself is part of the output of apply_chat_template.
+         # A simple way if the prompt is directly prepended:
+         if full_text.startswith(formatted_prompt.replace("<bos>", "").replace("<eos>", "")):  # Handle potential BOS/EOS tokens in prompt
+             decoded_text = full_text[len(formatted_prompt.replace("<bos>", "").replace("<eos>", "")):]
+         else:
+             # Fallback or more sophisticated stripping might be needed depending on the template.
+             # For Gemma's instruction-tuned template, this usually works by finding the assistant's turn start.
+             assistant_turn_start = "<start_of_turn>model\n"
+             if assistant_turn_start in full_text:
+                 decoded_text = full_text.split(assistant_turn_start, 1)[-1]
+             else:
+                 # If not found, it might be that the prompt itself wasn't fully included in the output
+                 # or the template is different. As a simpler fallback, we take the part after input_ids.
+                 decoded_text = tokenizer.decode(outputs[0, input_ids.input_ids.shape[1]:], skip_special_tokens=True)
+
+
+         print(f"Generated: {decoded_text[:100]}...")
+
+         return GenerationResponse(generated_text=decoded_text.strip(), input_prompt=request.prompt)
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         raise HTTPException(status_code=500, detail=f"Error during generation: {str(e)}")
+
+ @app.get("/")
+ async def root():
+     return {"message": "Gemma CPU Inference API is running. POST to /generate for inference."}
+
+ # To run locally (optional; uvicorn in CMD will handle it in Docker)
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
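
A minimal client sketch for the endpoint above, assuming the server is reachable at http://localhost:8000 (the uvicorn defaults in app.py) and that the third-party requests package is installed; the payload fields mirror the GenerationRequest model:

import requests

# Example payload; field names come from GenerationRequest in app.py.
payload = {
    "prompt": "Explain what a REST endpoint is in one sentence.",
    "max_new_tokens": 64,
    "temperature": 0.7,
    "do_sample": True,
}

# CPU generation can be slow, so allow a generous timeout.
response = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
response.raise_for_status()
print(response.json()["generated_text"])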
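
A second small sketch, useful when adjusting the prompt-stripping logic in /generate: it prints what apply_chat_template produces for a single user turn. It assumes local access to the gated google/gemma-2b-it tokenizer, and the expected markers shown in the comment are an assumption about Gemma's instruction-tuned template rather than a guarantee:

from transformers import AutoTokenizer

# Print the formatted prompt built by the chat template for one user turn.
# For Gemma instruction-tuned checkpoints this typically looks roughly like:
#   <bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
tok = AutoTokenizer.from_pretrained("google/gemma-2b-it")
chat = [{"role": "user", "content": "Hello!"}]
print(repr(tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)))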