AndaiMD committed on
Commit
f04940c
·
1 Parent(s): d1e903b
Files changed (1) hide show
  1. app/main.py +21 -16
app/main.py CHANGED
@@ -9,29 +9,34 @@ model, tokenizer = load_model()
@app.post("/predict")
async def predict(request: Request):
    """Generate a short continuation of the last five words of the input.

    Reads the JSON request body, takes its "input" value, keeps the final
    five whitespace-separated words as the prompt, samples up to 20 new
    tokens from the model and returns them as ``{"output": <continuation>}``.
    """
    data = await request.json()
    input_text = data.get("input", "")
    if not isinstance(input_text, str):
        # A null or non-string "input" would crash on the string methods
        # below; treat it the same as a missing value.
        input_text = ""

    # Only the last five words are used as the prompt.
    last_5_words = " ".join(input_text.split()[-5:])
    if not last_5_words:
        # An empty prompt may tokenize to zero tokens, which gives the model
        # nothing to condition on; answer with an empty continuation instead.
        return JSONResponse(content={"output": ""})

    # Tokenize and generate the continuation on the model's device.
    inputs = tokenizer(last_5_words, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt by token count rather than by character count: the
    # decoded text is not guaranteed to reproduce the prompt string exactly,
    # so slicing the decoded string by len(prompt) can cut in the wrong place.
    # NOTE(review): assumes generate() echoes the prompt tokens at the start
    # of its output, as the original string slicing also assumed - confirm
    # for the model returned by load_model().
    prompt_token_count = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_token_count:]
    continuation = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return JSONResponse(content={"output": continuation})
 
@app.post("/predict")
async def predict(request: Request):
    """Generate a continuation for the final sentence of an abstract.

    Reads the JSON request body, takes its "input" value, extracts the last
    sentence, wraps it in a fixed instruction prompt, samples up to 20 new
    tokens from the model and returns them as ``{"output": <continuation>}``.
    """
    import re

    data = await request.json()
    raw_abstract = data.get("input", "")
    if not isinstance(raw_abstract, str):
        # A null or non-string "input" would crash on .strip() below; treat
        # it the same as a missing value.
        raw_abstract = ""

    # Split after sentence-ending punctuation followed by spaces and keep the
    # last piece. re.split always returns at least one element, so a
    # single-sentence abstract yields the whole (stripped) text.
    sentences = re.split(r'(?<=[.!?]) +', raw_abstract.strip())
    abstract_tail = sentences[-1]

    # Construct the prompt.
    prompt = (
        "This neuroscience abstract ends as follows:\n"
        f"\"{abstract_tail}\"\n\n"
        "Complete the next sentence logically:"
    )

    # Tokenize and move the tensors to the model's device; leaving them on
    # the default device fails when the model is on a different one.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            # Explicit pad token, as in the version this code replaced.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt by token count rather than by character count: the
    # decoded text is not guaranteed to reproduce the prompt string exactly,
    # so slicing the decoded string by len(prompt) can cut in the wrong place.
    # NOTE(review): assumes generate() echoes the prompt tokens at the start
    # of its output, as the original string slicing also assumed - confirm
    # for the model returned by load_model().
    prompt_token_count = inputs["input_ids"].shape[1]
    new_tokens = outputs[0][prompt_token_count:]
    continuation = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return JSONResponse(content={"output": continuation})