qiuxi337 committed on
Commit 2009aa5 · verified · 1 Parent(s): 387cd41

Update app.py

Files changed (1)
  1. app.py +191 -386
app.py CHANGED
@@ -25,33 +25,34 @@ try:
25
  HAS_SPACES = True
26
  except ImportError:
27
  HAS_SPACES = False
28
- # Define placeholder decorator
29
  class spaces:
30
  @staticmethod
31
- def GPU():
32
- def decorator(func):
33
  return func
34
- return decorator
35
 
36
  # Check if GPU is available
37
  HAS_GPU = torch.cuda.is_available()
38
 
39
  # Try to install flash-attn (only in GPU environment)
40
- # if HAS_GPU:
41
- # try:
42
- # import subprocess
43
- # subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation',
44
- # env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
45
- # shell=True,
46
- # capture_output=True,
47
- # timeout=30)
48
- # HAS_FLASH_ATTN = True
49
- # except:
50
- # HAS_FLASH_ATTN = False
51
- # else:
52
- # HAS_FLASH_ATTN = False
53
- HAS_FLASH_ATTN = False
 
54
55
  # Default model checkpoint path
56
  DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'
57
 
@@ -59,7 +60,6 @@ DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'
59
  DEFAULT_SYSTEM_PROMPT = (
60
  "A conversation between user and assistant. The user asks a question, and the assistant solves it. The assistant "
61
  "first thinks about the reasoning process in the mind and then provides the user with the answer. "
62
- "The reasoning process is to solve the problem step by step, so you will think about it sincerely. "
63
  "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
64
  "<think> reasoning process here </think><answer> answer here </answer>."
65
  )
@@ -193,9 +193,9 @@ textarea:focus {
193
  def _get_args():
194
  """Parse command line arguments"""
195
  parser = ArgumentParser()
196
- parser.add_argument('-c', '--checkpoint-path',
197
  type=str,
198
- default=DEFAULT_CKPT_PATH,
199
  help='Checkpoint name or path, default to %(default)r')
200
  parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
201
  parser.add_argument('--share',
@@ -217,26 +217,21 @@ def encode_image_pil(image_path):
217
  """Encode image to base64 using PIL"""
218
  try:
219
  if isinstance(image_path, str):
220
- # It's a file path
221
  img = Image.open(image_path)
222
  elif isinstance(image_path, np.ndarray):
223
- # It's a numpy array
224
  img = Image.fromarray(image_path)
225
  elif isinstance(image_path, Image.Image):
226
- # It's already a PIL Image
227
  img = image_path
228
  else:
229
  print(f"Unsupported image type: {type(image_path)}")
230
  return None
231
-
232
- # Convert to RGB if necessary
233
  if img.mode not in ('RGB', 'RGBA'):
234
  img = img.convert('RGB')
235
-
236
- # Resize if too large
237
  max_size = (1024, 1024)
238
  img.thumbnail(max_size, Image.Resampling.LANCZOS)
239
-
240
  buffered = io.BytesIO()
241
  img.save(buffered, format="PNG")
242
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
@@ -248,54 +243,39 @@ def encode_image_pil(image_path):
248
  def _load_model_processor(args):
249
  """Intelligently load model, automatically choose CPU or GPU based on environment"""
250
  global HAS_GPU, HAS_FLASH_ATTN
251
-
252
- # Determine device to use
253
  use_gpu = HAS_GPU and not args.cpu_only
254
  device = 'cuda' if use_gpu else 'cpu'
255
-
256
  print(f"{'='*50}")
257
  print(f"🚀 Loading model: {args.checkpoint_path}")
258
  print(f"📱 Device: {'GPU (CUDA)' if use_gpu else 'CPU'}")
259
  print(f"⚡ Flash Attention: {'Enabled' if (use_gpu and HAS_FLASH_ATTN) else 'Disabled'}")
260
  print(f"{'='*50}")
261
-
262
- # Choose appropriate configuration based on device
263
  model_kwargs = {
264
  'pretrained_model_name_or_path': args.checkpoint_path,
265
  'torch_dtype': torch.bfloat16 if use_gpu else torch.float32,
266
  }
267
-
268
- # Use flash attention only on GPU when available
269
  if use_gpu and HAS_FLASH_ATTN:
270
  model_kwargs['attn_implementation'] = 'flash_attention_2'
271
-
272
- # Set device_map
273
  if use_gpu:
274
  model_kwargs['device_map'] = 'auto'
275
  else:
276
  model_kwargs['device_map'] = None
277
  model_kwargs['low_cpu_mem_usage'] = True
278
-
279
  try:
280
- # First try to use specific model class
281
- try:
282
- from transformers import Gemma3ForConditionalGeneration
283
- model = Gemma3ForConditionalGeneration.from_pretrained(**model_kwargs)
284
- except:
285
- # If failed, use generic AutoModel
286
- model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
287
-
288
  model.eval()
289
-
290
- # If CPU mode, manually move to CPU
291
  if not use_gpu:
292
  model = model.to(device)
293
-
294
  except Exception as e:
295
  print(f"⚠️ Failed to load model with optimal settings: {e}")
296
  print("🔄 Falling back to CPU mode...")
297
-
298
- # Fallback to CPU mode
299
  model_kwargs = {
300
  'pretrained_model_name_or_path': args.checkpoint_path,
301
  'torch_dtype': torch.float32,
@@ -307,9 +287,9 @@ def _load_model_processor(args):
307
  model.eval()
308
  use_gpu = False
309
  device = 'cpu'
310
-
311
  processor = AutoProcessor.from_pretrained(args.checkpoint_path)
312
-
313
  print(f"✅ Model loaded successfully on {device}")
314
  return model, processor, device
315
 
@@ -361,8 +341,7 @@ def _parse_text(text):
361
 
362
  def _remove_image_special(text):
363
  """Remove special image tags from text"""
364
- if text is None:
365
- return ""
366
  text = text.replace('<ref>', '').replace('</ref>', '')
367
  return re.sub(r'<box>.*?(</box>|$)', '', text)
368
 
@@ -377,391 +356,229 @@ def _gc():
377
 
378
  def _transform_messages(original_messages, system_prompt):
379
  """Transform messages with custom system prompt"""
380
- transformed_messages = []
381
- system_message = {"role": "system", "content": [{"type": "text", "text":system_prompt}]}
382
- transformed_messages.append(system_message)
383
-
384
  for message in original_messages:
385
  new_content = []
386
  for item in message['content']:
387
  if 'image' in item:
388
- new_item = {'type': 'image', 'image': item['image']}
389
  elif 'text' in item:
390
- new_item = {'type': 'text', 'text': item['text']}
391
- else:
392
- continue
393
- new_content.append(new_item)
394
-
395
- new_message = {'role': message['role'], 'content': new_content}
396
- transformed_messages.append(new_message)
397
-
398
  return transformed_messages
399
 
400
 
401
  def normalize_task_history_item(item):
402
- """Normalize items in task_history into a dict format"""
403
  if isinstance(item, dict):
404
- # Already a dict; check for the required keys
405
- return {
406
- 'text': item.get('text', ''),
407
- 'images': item.get('images', []),
408
- 'response': item.get('response', None)
409
- }
410
  elif isinstance(item, (list, tuple)) and len(item) >= 2:
411
- # Old format: (query, response)
412
  query, response = item[0], item[1]
413
  if isinstance(query, (list, tuple)):
414
- # query is a list of images
415
- return {
416
- 'text': '',
417
- 'images': list(query),
418
- 'response': response
419
- }
420
  else:
421
- # query is text
422
- return {
423
- 'text': str(query) if query else '',
424
- 'images': [],
425
- 'response': response
426
- }
427
  else:
428
- # Any other format; try to handle it
429
- return {
430
- 'text': str(item) if item else '',
431
- 'images': [],
432
- 'response': None
433
- }
434
 
435
 
436
  def _launch_demo(args, model, processor, device):
437
  """Launch the Gradio demo interface"""
438
-
439
- @spaces.GPU
440
  def call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
441
  """Call the local model with streaming response"""
442
  messages = _transform_messages(messages, system_prompt)
 
443
  inputs = processor.apply_chat_template(
444
  messages,
445
  add_generation_prompt=True,
446
  tokenize=True,
447
  return_dict=True,
448
  return_tensors="pt"
449
- ).to(model.device, dtype=torch.bfloat16)
450
-
451
  tokenizer = processor.tokenizer
452
  streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
453
-
454
  gen_kwargs = {
455
- 'max_new_tokens': max_tokens,
456
- "do_sample": True,
457
- "temperature": temperature,
458
- "top_p": top_p,
459
- "top_k": 20,
460
- 'streamer': streamer,
461
- **inputs
462
  }
463
-
464
  with torch.inference_mode():
465
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
466
  thread.start()
467
-
468
  generated_text = ''
469
  for new_text in streamer:
470
  generated_text += new_text
471
- if "<think>" in generated_text:
472
- generated_text = generated_text.replace("<think>", "**Reasoning Process**:\n")
473
- if "</think>" in generated_text:
474
- generated_text = generated_text.replace("</think>", "\n")
475
- if "<answer>" in generated_text:
476
- generated_text = generated_text.replace("<answer>", "**Final Answer**:\n")
477
- if "</answer>" in generated_text:
478
- generated_text = generated_text.replace("</answer>", "")
479
- yield generated_text
480
-
481
- def create_predict_fn():
482
- """Create prediction function with optional GPU acceleration"""
483
- def predict_impl(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
484
- """Implementation of prediction logic"""
485
- if not _chatbot or not task_history:
486
- yield _chatbot
487
- return
488
-
489
- chat_query = _chatbot[-1][0]
490
-
491
- # Normalize the last item in task_history
492
- last_item = normalize_task_history_item(task_history[-1])
493
-
494
- if not chat_query and not last_item['text'] and not last_item['images']:
495
- _chatbot.pop()
496
- task_history.pop()
497
- yield _chatbot
498
- return
499
-
500
- print(f'User query: {last_item}')
501
-
502
- # Normalize the whole history
503
- history_cp = [normalize_task_history_item(item) for item in copy.deepcopy(task_history)]
504
- full_response = ''
505
- messages = []
506
-
507
- # Build messages: make sure each user/assistant pair alternates correctly
508
- for i, item in enumerate(history_cp):
509
- if item['response'] is None: # Message currently being processed
510
- content = []
511
-
512
- # Add images
513
- if item['images']:
514
- for img_path in item['images']:
515
- if img_path:
516
- encoded_img = encode_image_pil(img_path)
517
- if encoded_img:
518
- content.append({'image': encoded_img})
519
-
520
- # Add text
521
- if item['text']:
522
- content.append({'text': str(item['text'])})
523
-
524
- if content:
525
- messages.append({'role': 'user', 'content': content})
526
- else: # Historical message
527
- content = []
528
-
529
- # Add images
530
- if item['images']:
531
- for img_path in item['images']:
532
- if img_path:
533
- encoded_img = encode_image_pil(img_path)
534
- if encoded_img:
535
- content.append({'image': encoded_img})
536
-
537
- # Add text
538
- if item['text']:
539
- content.append({'text': str(item['text'])})
540
-
541
- if content:
542
- messages.append({'role': 'user', 'content': content})
543
- messages.append({'role': 'assistant', 'content': [{'text': str(item['response'])}]})
544
-
545
- try:
546
- for response in call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
547
- _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
548
- yield _chatbot
549
- full_response = response
550
-
551
- # Update the response in task_history
552
- if isinstance(task_history[-1], dict):
553
- task_history[-1]['response'] = full_response
554
- else:
555
- # If it is the old format, convert it to the new format
556
- normalized_item = normalize_task_history_item(task_history[-1])
557
- normalized_item['response'] = full_response
558
- task_history[-1] = normalized_item
559
-
560
- print(f'Assistant: {full_response}')
561
- except Exception as e:
562
- print(f"Error during generation: {e}")
563
- import traceback
564
- traceback.print_exc()
565
- _chatbot[-1] = (_parse_text(chat_query), f"Error: {str(e)}")
566
-
567
- # Record the error message in task_history
568
- if isinstance(task_history[-1], dict):
569
- task_history[-1]['response'] = f"Error: {str(e)}"
570
- else:
571
- normalized_item = normalize_task_history_item(task_history[-1])
572
- normalized_item['response'] = f"Error: {str(e)}"
573
- task_history[-1] = normalized_item
574
-
575
  yield _chatbot
576
 
577
- # Use GPU decorator if spaces is available and using GPU
578
- if HAS_SPACES and device == 'cuda':
579
- predict = spaces.GPU()(predict_impl)
580
- else:
581
- predict = predict_impl
582
-
583
- return predict
584
-
585
- def create_regenerate_fn():
586
- """Create regenerate function"""
587
- def regenerate(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
588
- if not task_history or not _chatbot:
589
- yield _chatbot
590
- return
591
-
592
- # Normalize the last item
593
- last_item = normalize_task_history_item(task_history[-1])
594
-
595
- if last_item['response'] is None:
596
- yield _chatbot
597
- return
598
-
599
- # Reset the last response
600
- last_item['response'] = None
601
- task_history[-1] = last_item
602
-
603
- chatbot_item = _chatbot.pop(-1) if _chatbot else None
604
-
605
- if chatbot_item:
606
- if chatbot_item[0] is None and len(_chatbot) > 0:
607
- _chatbot[-1] = (_chatbot[-1][0], None)
608
- else:
609
- _chatbot.append((chatbot_item[0], None))
610
-
611
- # Use the predict function directly
612
- for updated_chatbot in predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
613
- yield updated_chatbot
614
-
615
- return regenerate
616
-
617
- predict = create_predict_fn()
618
- regenerate = create_regenerate_fn()
619
 
620
  def add_text_and_files(history, task_history, text, files):
621
- """Merge text and files into a single message"""
622
  history = history if history is not None else []
623
  task_history = task_history if task_history is not None else []
624
-
625
- # Check for valid input
626
  has_text = text and text.strip()
627
  has_files = files and len(files) > 0
628
-
629
  if not has_text and not has_files:
630
  return history, task_history, text, files
631
-
632
- # Prepare the message content
633
- display_parts = []
634
- file_paths = []
635
-
636
- # Process files
637
  if has_files:
638
  for file in files:
639
- if file is not None:
640
- file_path = file.name if hasattr(file, 'name') else str(file)
641
- file_paths.append(file_path)
642
-
643
  if file_paths:
644
  display_parts.append(f"[Uploaded {len(file_paths)} images]")
645
-
646
- # Process text
647
  if has_text:
648
  display_parts.append(text)
649
-
650
- # Create the display message
651
  display_message = " ".join(display_parts)
652
-
653
- # Append to history
654
  history.append([_parse_text(display_message), None])
655
- task_history.append({
656
- 'text': text if has_text else '',
657
- 'images': file_paths,
658
- 'response': None
659
- })
660
-
661
- return history, task_history, '', None # Clear inputs
662
 
663
- def reset_user_input():
664
- """Reset user input field"""
665
- return gr.update(value='')
666
 
667
  def reset_state():
668
- """Clear conversation history"""
669
  _gc()
670
- return [], [], None # Return empty chatbot, empty task_history, and None for file input
671
 
672
- # Create Gradio interface
673
  with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
674
- gr.HTML("""
675
  <div class="container">
676
  <h1 class="main-title">IntrinSight Assistant</h1>
677
  <p class="sub-title">
678
  Powered by IntrinSight-4B Model
679
- <span class="{}">{}</span>
680
  </p>
681
  </div>
682
- """.format(
683
- "status-indicator gpu-status" if device == 'cuda' else "status-indicator cpu-status",
684
- "🚀 GPU Mode" if device == 'cuda' else "💻 CPU Mode"
685
- ))
686
 
687
- # Initialize states
688
  task_history = gr.State([])
689
 
690
  with gr.Row():
691
  with gr.Column(scale=4):
692
  chatbot = gr.Chatbot(
693
- label='IntrinSight-4B Chat Interface',
694
- elem_classes='control-height',
695
- height=600,
696
  avatar_images=(None, "https://em-content.zobj.net/thumbs/240/twitter/348/robot_1f916.png")
697
  )
698
-
699
  with gr.Row():
700
- query = gr.Textbox(
701
- lines=3,
702
- label='💬 Message Input',
703
- placeholder="Enter your question here...",
704
- elem_classes="custom-input"
705
- )
706
-
707
  with gr.Row():
708
- # Multi-file upload with drag and drop support
709
  addfile_btn = gr.File(
710
- label="📸 Upload Images (Drag & Drop Supported, Multiple Selection)",
711
- file_count="multiple",
712
- file_types=["image"],
713
- elem_classes="file-upload-area"
714
  )
715
-
716
  with gr.Row():
717
  submit_btn = gr.Button('🚀 Send', variant="primary", elem_classes="custom-button")
718
  regen_btn = gr.Button('🔄 Regenerate', variant="secondary", elem_classes="custom-button")
719
  empty_bin = gr.Button('🗑️ Clear History', variant="stop", elem_classes="custom-button")
720
 
721
  with gr.Column(scale=2):
722
- # System prompt section
723
  with gr.Group(elem_classes="parameter-section"):
724
  gr.Markdown("### ⚙️ System Configuration")
725
- system_prompt = gr.Textbox(
726
- label="System Prompt",
727
- value=DEFAULT_SYSTEM_PROMPT,
728
- lines=5,
729
- placeholder="Enter system prompt here..."
730
- )
731
-
732
- # Generation parameters section
733
  with gr.Group(elem_classes="parameter-section"):
734
  gr.Markdown("### 🎛️ Generation Parameters")
735
-
736
- temperature = gr.Slider(
737
- minimum=0.1,
738
- maximum=2.0,
739
- value=0.7,
740
- step=0.1,
741
- label="Temperature (Creativity)",
742
- info="Higher values make output more random"
743
- )
744
-
745
- top_p = gr.Slider(
746
- minimum=0.1,
747
- maximum=1.0,
748
- value=1.0,
749
- step=0.05,
750
- label="Top-p (Nucleus Sampling)",
751
- info="Cumulative probability for token selection"
752
- )
753
-
754
- max_tokens = gr.Slider(
755
- minimum=256,
756
- maximum=16384,
757
- value=8192,
758
- step=256,
759
- label="Max Tokens",
760
- info="Maximum number of tokens to generate"
761
- )
762
-
763
- # Instructions section
764
- gr.Markdown("""
765
  ### 📋 Instructions
766
 
767
  **Basic Usage:**
@@ -771,65 +588,53 @@ def _launch_demo(args, model, processor, device):
771
  - **Parameters**: Adjust generation settings as needed
772
 
773
  **Performance Info:**
774
- - Current Mode: **{}**
775
- - Flash Attention: **{}**
776
  - Recommended Image Size: < 1024×1024
777
 
778
  ### ⚠️ Disclaimer
779
 
780
  This demo is subject to the Gemma license agreement.
781
  Please do not generate or disseminate harmful content.
782
- """.format(
783
- "GPU Acceleration" if device == 'cuda' else "CPU Mode",
784
- "Enabled" if (device == 'cuda' and HAS_FLASH_ATTN) else "Disabled"
785
- ))
786
 
787
- # Event bindings
788
  submit_btn.click(
789
- add_text_and_files,
790
  [chatbot, task_history, query, addfile_btn],
791
  [chatbot, task_history, query, addfile_btn]
792
  ).then(
793
- predict,
794
- [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
795
- [chatbot],
796
- show_progress=True
797
- )
798
-
799
- empty_bin.click(
800
- reset_state,
801
- outputs=[chatbot, task_history, addfile_btn],
802
- show_progress=True
803
  )
804
-
805
  regen_btn.click(
806
- regenerate,
807
- [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
808
- [chatbot],
809
- show_progress=True
810
  )
811
-
812
- # Enter key to send message
813
  query.submit(
814
- add_text_and_files,
815
  [chatbot, task_history, query, addfile_btn],
816
  [chatbot, task_history, query, addfile_btn]
817
  ).then(
818
- predict,
819
- [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
820
- [chatbot],
821
- show_progress=True
822
  )
823
 
824
  demo.queue(max_size=10).launch(
825
- share=args.share,
826
- inbrowser=args.inbrowser,
827
- server_port=args.server_port,
828
- server_name=args.server_name,
829
  show_error=True
830
  )
831
 
832
-
833
  def main():
834
  """Main entry point"""
835
  args = _get_args()
 
25
  HAS_SPACES = True
26
  except ImportError:
27
  HAS_SPACES = False
 
28
  class spaces:
29
  @staticmethod
30
+ def GPU(func=None, **kwargs):
31
+ if func:
32
  return func
33
+ return lambda f: f
34
 
35
  # Check if GPU is available
36
  HAS_GPU = torch.cuda.is_available()
37
 
38
  # Try to install flash-attn (only in GPU environment)
39
+ if HAS_GPU:
40
+ try:
41
+ import subprocess
42
+ subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation',
43
+ env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
44
+ shell=True,
45
+ capture_output=True,
46
+ timeout=30)
47
+ import flash_attn
48
+ HAS_FLASH_ATTN = True
49
+ except Exception as e:
50
+ print(f"Flash Attention installation failed: {e}")
51
+ HAS_FLASH_ATTN = False
52
+ else:
53
+ HAS_FLASH_ATTN = False
54
 
55
+ HAS_FLASH_ATTN = False
56
  # Default model checkpoint path
57
  DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'
58
60
  DEFAULT_SYSTEM_PROMPT = (
61
  "A conversation between user and assistant. The user asks a question, and the assistant solves it. The assistant "
62
  "first thinks about the reasoning process in the mind and then provides the user with the answer. "
 
63
  "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
64
  "<think> reasoning process here </think><answer> answer here </answer>."
65
  )
 
193
  def _get_args():
194
  """Parse command line arguments"""
195
  parser = ArgumentParser()
196
+ parser.add_argument('-c', '--checkpoint-path',
197
  type=str,
198
+ default=DEFAULT_CKPT_PATH,
199
  help='Checkpoint name or path, default to %(default)r')
200
  parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
201
  parser.add_argument('--share',
 
217
  """Encode image to base64 using PIL"""
218
  try:
219
  if isinstance(image_path, str):
 
220
  img = Image.open(image_path)
221
  elif isinstance(image_path, np.ndarray):
 
222
  img = Image.fromarray(image_path)
223
  elif isinstance(image_path, Image.Image):
 
224
  img = image_path
225
  else:
226
  print(f"Unsupported image type: {type(image_path)}")
227
  return None
228
+
 
229
  if img.mode not in ('RGB', 'RGBA'):
230
  img = img.convert('RGB')
231
+
 
232
  max_size = (1024, 1024)
233
  img.thumbnail(max_size, Image.Resampling.LANCZOS)
234
+
235
  buffered = io.BytesIO()
236
  img.save(buffered, format="PNG")
237
  return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
243
  def _load_model_processor(args):
244
  """Intelligently load model, automatically choose CPU or GPU based on environment"""
245
  global HAS_GPU, HAS_FLASH_ATTN
246
+
 
247
  use_gpu = HAS_GPU and not args.cpu_only
248
  device = 'cuda' if use_gpu else 'cpu'
249
+
250
  print(f"{'='*50}")
251
  print(f"🚀 Loading model: {args.checkpoint_path}")
252
  print(f"📱 Device: {'GPU (CUDA)' if use_gpu else 'CPU'}")
253
  print(f"⚡ Flash Attention: {'Enabled' if (use_gpu and HAS_FLASH_ATTN) else 'Disabled'}")
254
  print(f"{'='*50}")
255
+
 
256
  model_kwargs = {
257
  'pretrained_model_name_or_path': args.checkpoint_path,
258
  'torch_dtype': torch.bfloat16 if use_gpu else torch.float32,
259
  }
260
+
 
261
  if use_gpu and HAS_FLASH_ATTN:
262
  model_kwargs['attn_implementation'] = 'flash_attention_2'
263
+
 
264
  if use_gpu:
265
  model_kwargs['device_map'] = 'auto'
266
  else:
267
  model_kwargs['device_map'] = None
268
  model_kwargs['low_cpu_mem_usage'] = True
269
+
270
  try:
271
+ model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
272
  model.eval()
273
+ # Note: even with device_map='auto', we might need to move a CPU-only model explicitly
 
274
  if not use_gpu:
275
  model = model.to(device)
 
276
  except Exception as e:
277
  print(f"⚠️ Failed to load model with optimal settings: {e}")
278
  print("🔄 Falling back to CPU mode...")
279
  model_kwargs = {
280
  'pretrained_model_name_or_path': args.checkpoint_path,
281
  'torch_dtype': torch.float32,
 
287
  model.eval()
288
  use_gpu = False
289
  device = 'cpu'
290
+
291
  processor = AutoProcessor.from_pretrained(args.checkpoint_path)
292
+
293
  print(f"✅ Model loaded successfully on {device}")
294
  return model, processor, device
295
 
 
341
 
342
  def _remove_image_special(text):
343
  """Remove special image tags from text"""
344
+ if text is None: return ""
 
345
  text = text.replace('<ref>', '').replace('</ref>', '')
346
  return re.sub(r'<box>.*?(</box>|$)', '', text)
347
 
 
356
 
357
  def _transform_messages(original_messages, system_prompt):
358
  """Transform messages with custom system prompt"""
359
+ transformed_messages = [{"role": "system", "content": [{"type": "text", "text":system_prompt}]}]
360
  for message in original_messages:
361
  new_content = []
362
  for item in message['content']:
363
  if 'image' in item:
364
+ new_content.append({'type': 'image', 'image': item['image']})
365
  elif 'text' in item:
366
+ new_content.append({'type': 'text', 'text': item['text']})
367
+ if new_content:
368
+ transformed_messages.append({'role': message['role'], 'content': new_content})
369
  return transformed_messages
370
 
371
 
372
  def normalize_task_history_item(item):
373
+ """Normalize items in task_history to a dictionary format"""
374
  if isinstance(item, dict):
375
+ return {'text': item.get('text', ''), 'images': item.get('images', []), 'response': item.get('response', None)}
376
  elif isinstance(item, (list, tuple)) and len(item) >= 2:
 
377
  query, response = item[0], item[1]
378
  if isinstance(query, (list, tuple)):
379
+ return {'text': '', 'images': list(query), 'response': response}
380
  else:
381
+ return {'text': str(query) if query else '', 'images': [], 'response': response}
382
  else:
383
+ return {'text': str(item) if item else '', 'images': [], 'response': None}
384
 
385
 
386
  def _launch_demo(args, model, processor, device):
387
  """Launch the Gradio demo interface"""
388
+
 
389
  def call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
390
  """Call the local model with streaming response"""
391
  messages = _transform_messages(messages, system_prompt)
392
+
393
  inputs = processor.apply_chat_template(
394
  messages,
395
  add_generation_prompt=True,
396
  tokenize=True,
397
  return_dict=True,
398
  return_tensors="pt"
399
+ )
400
+
401
+ # ====================================================================
402
+ # THE FINAL, ROBUST FIX for all environments (CUDA, ZeroGPU, CPU)
403
+ # We must move the input tensors to the correct device.
404
+ # However, to be compatible with ZeroGPU's `torch.compile`, we must use
405
+ # a string ('cuda' or 'cpu') instead of a `torch.device` object.
406
+ # The `device` variable (a string) is passed in from the parent scope.
407
+ # This prevents both the "device mismatch" error and the "ConstantVariable" error.
408
+ # ====================================================================
409
+ inputs = inputs.to(device)
410
+ # ====================================================================
411
+
412
  tokenizer = processor.tokenizer
413
  streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
414
+
415
  gen_kwargs = {
416
+ 'max_new_tokens': max_tokens, "do_sample": True, "temperature": temperature,
417
+ "top_p": top_p, "top_k": 20, 'streamer': streamer, **inputs
418
  }
419
+
420
  with torch.inference_mode():
421
  thread = Thread(target=model.generate, kwargs=gen_kwargs)
422
  thread.start()
 
423
  generated_text = ''
424
  for new_text in streamer:
425
  generated_text += new_text
426
+ display_text = generated_text
427
+ if "<think>" in display_text: display_text = display_text.replace("<think>", "**Reasoning Process**:\n")
428
+ if "</think>" in display_text: display_text = display_text.replace("</think>", "\n")
429
+ if "<answer>" in display_text: display_text = display_text.replace("<answer>", "**Final Answer**:\n")
430
+ if "</answer>" in display_text: display_text = display_text.replace("</answer>", "")
431
+ yield display_text, generated_text
432
+
433
+ @spaces.GPU
434
+ def predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
435
+ if not _chatbot or not task_history:
436
+ yield _chatbot
437
+ return
438
+
439
+ chat_query = _chatbot[-1][0]
440
+ last_item = normalize_task_history_item(task_history[-1])
441
+
442
+ if not chat_query and not last_item['text'] and not last_item['images']:
443
+ _chatbot.pop()
444
+ task_history.pop()
445
+ yield _chatbot
446
+ return
447
+
448
+ print(f'User query: {last_item}')
449
+ history_cp = [normalize_task_history_item(item) for item in copy.deepcopy(task_history)]
450
+ full_response_raw = ''
451
+ messages = []
452
+
453
+ for i, item in enumerate(history_cp):
454
+ content = []
455
+ if item['images']:
456
+ for img_path in item['images']:
457
+ if img_path:
458
+ encoded_img = encode_image_pil(img_path)
459
+ if encoded_img: content.append({'image': encoded_img})
460
+ if item['text']: content.append({'text': str(item['text'])})
461
+
462
+ if item['response'] is None:
463
+ if content: messages.append({'role': 'user', 'content': content})
464
+ else:
465
+ if content: messages.append({'role': 'user', 'content': content})
466
+ messages.append({'role': 'assistant', 'content': [{'text': str(item['response'])}]})
467
+
468
+ try:
469
+ for response_display, response_raw in call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
470
+ _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response_display)))
471
  yield _chatbot
472
+ full_response_raw = response_raw
473
+
474
+ task_history[-1]['response'] = full_response_raw
475
+ print(f'Assistant: {full_response_raw}')
476
+ except Exception as e:
477
+ print(f"Error during generation: {e}")
478
+ import traceback
479
+ traceback.print_exc()
480
+ error_msg = f"Error: {str(e)}"
481
+ _chatbot[-1] = (_parse_text(chat_query), error_msg)
482
+ task_history[-1]['response'] = error_msg
483
+ yield _chatbot
484
+
485
+ @spaces.GPU
486
+ def regenerate(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
487
+ if not task_history or not _chatbot:
488
+ yield _chatbot
489
+ return
490
+
491
+ last_item = normalize_task_history_item(task_history[-1])
492
+ if last_item['response'] is None:
493
+ yield _chatbot
494
+ return
495
+
496
+ last_item['response'] = None
497
+ task_history[-1] = last_item
498
+ _chatbot.pop(-1)
499
 
500
+ display_message_parts = []
501
+ if last_item['images']: display_message_parts.append(f"[Uploaded {len(last_item['images'])} images]")
502
+ if last_item['text']: display_message_parts.append(last_item['text'])
503
+ display_message = " ".join(display_message_parts)
504
+ _chatbot.append([_parse_text(display_message), None])
505
+
506
+ for updated_chatbot in predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
507
+ yield updated_chatbot
508
 
509
  def add_text_and_files(history, task_history, text, files):
 
510
  history = history if history is not None else []
511
  task_history = task_history if task_history is not None else []
512
+
 
513
  has_text = text and text.strip()
514
  has_files = files and len(files) > 0
515
+
516
  if not has_text and not has_files:
517
  return history, task_history, text, files
518
+
519
+ display_parts, file_paths = [], []
520
  if has_files:
521
  for file in files:
522
+ if file and hasattr(file, 'name'):
523
+ file_paths.append(file.name)
524
  if file_paths:
525
  display_parts.append(f"[Uploaded {len(file_paths)} images]")
526
  if has_text:
527
  display_parts.append(text)
528
+
 
529
  display_message = " ".join(display_parts)
530
  history.append([_parse_text(display_message), None])
531
+ task_history.append({'text': text if has_text else '', 'images': file_paths, 'response': None})
532
 
533
+ return history, task_history, '', None
534
 
535
  def reset_state():
 
536
  _gc()
537
+ return [], [], None
538
539
  with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
540
+ gr.HTML(f"""
541
  <div class="container">
542
  <h1 class="main-title">IntrinSight Assistant</h1>
543
  <p class="sub-title">
544
  Powered by IntrinSight-4B Model
545
+ <span class="status-indicator {'gpu-status' if device == 'cuda' else 'cpu-status'}">
546
+ {'🚀 GPU Mode' if device == 'cuda' else '💻 CPU Mode'}
547
+ </span>
548
  </p>
549
  </div>
550
+ """)
551
 
 
552
  task_history = gr.State([])
553
 
554
  with gr.Row():
555
  with gr.Column(scale=4):
556
  chatbot = gr.Chatbot(
557
+ label='IntrinSight-4B Chat Interface', elem_classes='control-height', height=600,
558
  avatar_images=(None, "https://em-content.zobj.net/thumbs/240/twitter/348/robot_1f916.png")
559
  )
 
560
  with gr.Row():
561
+ query = gr.Textbox(lines=3, label='💬 Message Input', placeholder="Enter your question here...", elem_classes="custom-input")
562
  with gr.Row():
 
563
  addfile_btn = gr.File(
564
+ label="📸 Upload Images (Drag & Drop Supported, Multiple Selection)", file_count="multiple",
565
+ file_types=["image"], elem_classes="file-upload-area"
566
  )
 
567
  with gr.Row():
568
  submit_btn = gr.Button('🚀 Send', variant="primary", elem_classes="custom-button")
569
  regen_btn = gr.Button('🔄 Regenerate', variant="secondary", elem_classes="custom-button")
570
  empty_bin = gr.Button('🗑️ Clear History', variant="stop", elem_classes="custom-button")
571
 
572
  with gr.Column(scale=2):
 
573
  with gr.Group(elem_classes="parameter-section"):
574
  gr.Markdown("### ⚙️ System Configuration")
575
+ system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=5, placeholder="Enter system prompt here...")
576
  with gr.Group(elem_classes="parameter-section"):
577
  gr.Markdown("### 🎛️ Generation Parameters")
578
+ temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature (Creativity)", info="Higher values make output more random")
579
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (Nucleus Sampling)", info="Cumulative probability for token selection")
580
+ max_tokens = gr.Slider(minimum=256, maximum=16384, value=8192, step=256, label="Max Tokens", info="Maximum number of tokens to generate")
581
+ gr.Markdown(f"""
582
  ### 📋 Instructions
583
 
584
  **Basic Usage:**
 
588
  - **Parameters**: Adjust generation settings as needed
589
 
590
  **Performance Info:**
591
+ - Current Mode: **{'GPU Acceleration' if device == 'cuda' else 'CPU Mode'}**
592
+ - Flash Attention: **{'Enabled' if (device == 'cuda' and HAS_FLASH_ATTN) else 'Disabled'}**
593
  - Recommended Image Size: < 1024×1024
594
 
595
  ### ⚠️ Disclaimer
596
 
597
  This demo is subject to the Gemma license agreement.
598
  Please do not generate or disseminate harmful content.
599
+ """)
600
 
 
601
  submit_btn.click(
602
+ add_text_and_files,
603
  [chatbot, task_history, query, addfile_btn],
604
  [chatbot, task_history, query, addfile_btn]
605
  ).then(
606
+ predict,
607
+ [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
608
+ [chatbot],
609
+ show_progress="full"
610
  )
611
+
612
+ empty_bin.click(reset_state, outputs=[chatbot, task_history, addfile_btn], show_progress=True)
613
+
614
  regen_btn.click(
615
+ regenerate,
616
+ [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
617
+ [chatbot],
618
+ show_progress="full"
619
  )
620
+
 
621
  query.submit(
622
+ add_text_and_files,
623
  [chatbot, task_history, query, addfile_btn],
624
  [chatbot, task_history, query, addfile_btn]
625
  ).then(
626
+ predict,
627
+ [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
628
+ [chatbot],
629
+ show_progress="full"
630
  )
631
 
632
  demo.queue(max_size=10).launch(
633
+ share=args.share, inbrowser=args.inbrowser,
634
+ server_port=args.server_port, server_name=args.server_name,
635
  show_error=True
636
  )
637
 
 
638
  def main():
639
  """Main entry point"""
640
  args = _get_args()