Adding Generic Reasoning Budget (#23)
Adding Generic Reasoning Budget (c0f14ee90b57a8218f7c6170725c2ef6e5489944)
Co-authored-by: Chris Alexiuk <llm-wizard@users.noreply.huggingface.co>
README.md CHANGED

@@ -388,6 +388,117 @@ python3 -m sglang.launch_server --model-path nvidia/NVIDIA-Nemotron-3-Nano-30B-A

```
--reasoning-parser nano_v3
```

#### Using Budget Control
The thinking budget allows developers to keep accuracy high while meeting response-time targets, which is especially crucial for customer support, autonomous agent steps, and edge devices where every millisecond counts.

With budget control, you can set a limit for internal reasoning:

* `reasoning_budget`: a threshold at which the server attempts to end the reasoning trace at the next newline. If no newline is encountered within 500 tokens, the trace is cut off abruptly at `reasoning_budget + 500` (see the sketch below).
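As a rough mental model of that rule (not the server's actual implementation), the sketch below applies the newline-or-hard-cutoff logic to a list of already-decoded tokens; the function name and token-level framing are illustrative assumptions.

```py
# Illustrative sketch of the budget rule, NOT the real server-side code:
# once `reasoning_budget` tokens have been generated, stop at the next
# newline; if none appears within 500 tokens, hard-stop at budget + 500.
def truncate_reasoning(tokens: list[str], reasoning_budget: int, grace: int = 500) -> list[str]:
    if len(tokens) <= reasoning_budget:
        return tokens  # budget never reached; the trace ends on its own
    for i in range(reasoning_budget, min(len(tokens), reasoning_budget + grace)):
        if "\n" in tokens[i]:
            return tokens[: i + 1]  # clean stop at the next newline
    return tokens[: reasoning_budget + grace]  # abrupt cutoff
```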
> NOTE: This client will work with any OpenAI-API-compatible endpoint.

Client for supporting budget control:
```py
from typing import Any, Dict, List

import openai
from transformers import AutoTokenizer


class ThinkingBudgetClient:
    def __init__(self, base_url: str, api_key: str, tokenizer_name_or_path: str):
        self.base_url = base_url
        self.api_key = api_key
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        self.client = openai.OpenAI(base_url=self.base_url, api_key=self.api_key)

    def chat_completion(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        reasoning_budget: int = 512,
        max_tokens: int = 1024,
        **kwargs,
    ) -> Dict[str, Any]:
        assert (
            max_tokens > reasoning_budget
        ), f"thinking budget must be smaller than maximum new tokens. Given {max_tokens=} and {reasoning_budget=}"

        # 1. Call chat completion, capped at the reasoning budget, to get the reasoning content.
        response = self.client.chat.completions.create(
            model=model, messages=messages, max_tokens=reasoning_budget, **kwargs
        )
        content = response.choices[0].message.content

        reasoning_content = content
        if "</think>" not in reasoning_content:
            # Reasoning was cut off by the budget: close it with a period and the end-of-think tag.
            reasoning_content = f"{reasoning_content}.\n</think>\n\n"
        reasoning_tokens_len = len(
            self.tokenizer.encode(reasoning_content, add_special_tokens=False)
        )
        remaining_tokens = max_tokens - reasoning_tokens_len
        assert (
            remaining_tokens > 0
        ), f"remaining tokens must be positive. Given {remaining_tokens=}. Increase max_tokens or lower the reasoning_budget."

        # 2. Append the reasoning content to the messages and continue generation
        #    with a raw completion call so the model writes the final answer.
        messages.append({"role": "assistant", "content": reasoning_content})
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            continue_final_message=True,
        )
        response = self.client.completions.create(
            model=model, prompt=prompt, max_tokens=remaining_tokens, **kwargs
        )

        response_data = {
            # str.strip("</think>") would strip characters, not the tag; use removesuffix instead.
            "reasoning_content": reasoning_content.strip().removesuffix("</think>").strip(),
            "content": response.choices[0].text,
            "finish_reason": response.choices[0].finish_reason,
        }
        return response_data
```
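In short, the client works in two phases: the first chat-completion call is capped at `reasoning_budget` tokens, and if the model has not emitted `</think>` on its own the trace is closed manually; the second, raw completion call re-renders the conversation with `continue_final_message=True` so the model resumes after the closed thinking block and writes the final answer within the remaining token budget. Note that `chat_completion` appends the reasoning message to the `messages` list you pass in, so reuse the list with care.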
Calling the server with a budget (restricted to 32 tokens here as an example):

```py
tokenizer_name_or_path = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"
client = ThinkingBudgetClient(
    base_url="http://localhost:8000/v1",  # Nemotron 3 Nano deployed in thinking mode
    api_key="EMPTY",
    tokenizer_name_or_path=tokenizer_name_or_path,
)

result = client.chat_completion(
    model="nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
    messages=[
        {"role": "system", "content": "You are a helpful assistant. /think"},
        {"role": "user", "content": "What is 2+2?"},
    ],
    reasoning_budget=32,
    max_tokens=512,
    temperature=1.0,
    top_p=1.0,
)
print(result)
```
You should see output similar to the following:

```
{'reasoning_content': "Okay, the user asked, What is 2+2? Let me think. Well, 2 plus 2 equals 4. That's a basic.", 'content': '2 + 2 equals **4**.\n', 'finish_reason': 'stop'}
```
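As a quick sanity check (an addition, not part of the original example), you can re-tokenize the returned reasoning trace with the same tokenizer and confirm it stayed near the 32-token budget:

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16")
used = len(tokenizer.encode(result["reasoning_content"], add_special_tokens=False))
print(f"reasoning tokens used: {used}")  # expect roughly the 32-token budget
```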
## Model Version(s)

- v1.0