import spaces
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch
from threading import Thread
# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda" if torch.cuda.is_available() else "cpu"
# === GPTQ 2-bit QUANTIZATION CONFIG ===
quantize_config = BaseQuantizeConfig(
    bits=2,          # 2-bit quantization
    group_size=128,  # grouping size
    desc_act=False   # disable activation-order (desc_act) quantization for faster inference
)
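# Note: from_quantized() below loads weights that are already GPTQ-quantized;
# this config only describes the expected format, it does not re-quantize the model.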
# === LOAD GPTQ-QUANTIZED MODEL ===
model = AutoGPTQForCausalLM.from_quantized(
    phi4_model_path,
    quantize_config=quantize_config,
    device_map="auto",
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
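# Defensive addition (not required if the checkpoint already defines one): some
# causal-LM tokenizers ship without a pad token, so fall back to EOS to avoid
# generate() warnings when an attention mask is passed.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token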
# === OPTIONAL: torch.compile for optimization (PyTorch >= 2.0) ===
try:
    model = torch.compile(model)
except Exception:
    pass
# === STREAMING RESPONSE GENERATOR ===
@spaces.GPU  # ZeroGPU: request a GPU for the duration of each generation call
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        # Nothing to generate; echo the current history back (yield, since this is a generator)
        yield history_state, history_state
        return

    # System prompt prefix
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"

    # Build full prompt
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for msg in history_state:
        prompt += f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
    # Tokenize and move to device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Set up streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer
    }
    # Launch generation
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""}
    ]

    # Stream tokens back to Gradio
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history

    yield new_history, new_history
# === EXAMPLE MESSAGES ===
example_messages = {
    "Math reasoning": "If a rectangular prism has a length of 6 cm...",
    "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
    "Physics problem": "A ball is thrown upward with an initial velocity..."
}
# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Phi-4 Chat with GPTQ Quant
    Try the example problems below to see how the model breaks down complex reasoning.
    """)
    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name, text in example_messages.items():
                    btn = gr.Button(name)
                    btn.click(fn=lambda t=text: gr.update(value=t), inputs=None, outputs=user_input)

    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(lambda: gr.update(value=""), None, user_input)
    clear_button.click(lambda: ([], []), None, [chatbot, history_state])

demo.launch(ssr_mode=False)
# Note:
# To restore the fused CUDA kernels for faster GPTQ inference, reinstall AutoGPTQ with CUDA support:
#   pip install git+https://github.com/PanQiWei/AutoGPTQ.git#egg=auto-gptq[cuda]
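# Quick sanity check after reinstalling (assumes auto_gptq exposes __version__):
#   python -c "import torch, auto_gptq; print(torch.cuda.is_available(), auto_gptq.__version__)"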