"""
Custom handler for Hugging Face Inference Endpoints.

Model: ongilLabs/IB-Math-Instruct-7B
"""

from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        self.default_system = """You are an expert IB Mathematics AA tutor. Your role is to explain mathematical concepts and solve problems using pure mathematical reasoning, NOT programming code.

CRITICAL RULES:
1. NEVER write Python, SymPy, or any programming code
2. Use ONLY mathematical notation and LaTeX ($...$ for inline, $$...$$ for display)
3. Show step-by-step solutions with clear mathematical reasoning
4. Use IB command terms appropriately (Find, Show, Hence, Prove, etc.)
5. Include common pitfall warnings when relevant
6. End with IB exam tips about marking schemes (M marks, A marks)
7. Write in a teacher-like, encouraging tone
8. Use <think> tags to show your reasoning process before the solution

Your responses should be like a professional IB teacher explaining to students, using mathematical notation and clear explanations."""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Args:
            data: Dictionary with 'inputs' (str or list of chat messages)
                and optional 'parameters'.

        Returns:
            Dictionary with 'generated_text', or 'error' on invalid input.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Generation parameters, with defaults.
        max_new_tokens = parameters.get("max_new_tokens", 1024)
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)
        system_prompt = parameters.get("system_prompt", self.default_system)

        # Build the chat messages, prepending the system prompt when missing.
        if isinstance(inputs, str):
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": inputs},
            ]
        elif isinstance(inputs, list):
            messages = inputs
            if messages and messages[0].get("role") != "system":
                messages = [{"role": "system", "content": system_prompt}] + messages
        else:
            return {"error": "Invalid input format. Expected string or list of messages."}

        # Render the chat template and tokenize the prompt.
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else None,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt.
        response = self.tokenizer.decode(
            outputs[0][model_inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return {"generated_text": response}
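

# ---------------------------------------------------------------------------
# Example local smoke test (a minimal sketch; not part of the Inference
# Endpoints runtime, which only imports EndpointHandler). It illustrates the
# payload shape accepted by __call__. The question text and generation
# parameters are purely illustrative, and loading the 7B model this way
# assumes sufficient GPU memory is available locally.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler(path="ongilLabs/IB-Math-Instruct-7B")
    result = handler(
        {
            "inputs": "Find the derivative of f(x) = x^2 sin(x).",
            "parameters": {"max_new_tokens": 512, "temperature": 0.7},
        }
    )
    print(result.get("generated_text", result))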