IB-Math-Ontology-7B / handler.py
"""
Custom handler for Hugging Face Inference Endpoints
Model: ongilLabs/IB-Math-Instruct-7B
"""
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path: str = ""):
        """Initialize the model and tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        # Default system prompt - IB Math AA tutor style
        self.default_system = """You are an expert IB Mathematics AA tutor. Your role is to explain mathematical concepts and solve problems using pure mathematical reasoning, NOT programming code.

CRITICAL RULES:
1. NEVER write Python, SymPy, or any programming code
2. Use ONLY mathematical notation and LaTeX ($...$ for inline, $$...$$ for display)
3. Show step-by-step solutions with clear mathematical reasoning
4. Use IB command terms appropriately (Find, Show, Hence, Prove, etc.)
5. Include common pitfall warnings when relevant
6. End with IB exam tips about marking schemes (M marks, A marks)
7. Write in a teacher-like, encouraging tone
8. Use <think> tags to show your reasoning process before the solution

Your responses should be like a professional IB teacher explaining to students, using mathematical notation and clear explanations."""

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle an inference request.

        Args:
            data: Dictionary with 'inputs' (str or list of messages) and
                optional 'parameters'.

        Returns:
            Dictionary with 'generated_text', or 'error' if the input
            format is invalid.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        # Extract generation parameters, falling back to defaults
        max_new_tokens = parameters.get("max_new_tokens", 1024)  # generous default to avoid truncating long solutions
        temperature = parameters.get("temperature", 0.7)
        top_p = parameters.get("top_p", 0.9)
        system_prompt = parameters.get("system_prompt", self.default_system)

        # Handle both plain-string and message-list inputs
        if isinstance(inputs, str):
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": inputs},
            ]
        elif isinstance(inputs, list):
            # Assume it is already a list of chat messages;
            # prepend the system prompt if one is not present
            messages = inputs
            if messages and messages[0].get("role") != "system":
                messages = [{"role": "system", "content": system_prompt}] + messages
        else:
            return {"error": "Invalid input format. Expected string or list of messages."}

        # Apply the model's chat template to build the prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Sample only when temperature is positive; greedy decoding otherwise.
        # Passing temperature/top_p alongside do_sample=False triggers warnings
        # in transformers, so only include them when actually sampling.
        do_sample = temperature > 0
        generate_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**model_inputs, **generate_kwargs)

        # Decode only the newly generated tokens (exclude the prompt)
        response = self.tokenizer.decode(
            outputs[0][model_inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        return {"generated_text": response}