|
|
"""LangGraph Agent with Direct Groq API and Custom Rate Limiting""" |
|
|
import os |
|
|
import time |
|
|
import threading |
|
|
from collections import deque |
|
|
from typing import Dict, Any, List |
|
|
from dotenv import load_dotenv |
|
|
from langgraph.graph import START, StateGraph, MessagesState |
|
|
from langgraph.prebuilt import tools_condition |
|
|
from langgraph.prebuilt import ToolNode |
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
|
from langchain_community.document_loaders import WikipediaLoader |
|
|
from langchain_community.document_loaders import ArxivLoader |
|
|
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
|
|
from langchain_core.tools import tool |
|
|
from groq import Groq, RateLimitError |
|
|
import logging |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class GroqRateLimiter: |
|
|
"""Thread-safe rate limiter for direct Groq API calls""" |
|
|
|
|
|
def __init__(self, rpm: int = 20, tpm: int = 6000): |
|
|
self.rpm = rpm |
|
|
self.tpm = tpm |
|
|
self.request_times = deque() |
|
|
self.token_usage = deque() |
|
|
self.lock = threading.Lock() |
|
|
|
|
|
def _clean_old_records(self, current_time: float): |
|
|
"""Remove records older than 1 minute""" |
|
|
minute_ago = current_time - 60 |
|
|
|
|
|
while self.request_times and self.request_times[0] <= minute_ago: |
|
|
self.request_times.popleft() |
|
|
|
|
|
while self.token_usage and self.token_usage[0][0] <= minute_ago: |
|
|
self.token_usage.popleft() |
|
|
|
|
|
def can_make_request(self, estimated_tokens: int = 1000) -> tuple[bool, float]: |
|
|
"""Check if request can be made, return (can_proceed, wait_time)""" |
|
|
with self.lock: |
|
|
current_time = time.time() |
|
|
self._clean_old_records(current_time) |
|
|
|
|
|
wait_time = 0 |
|
|
|
|
|
|
|
|
if len(self.request_times) >= self.rpm: |
|
|
oldest_request = self.request_times[0] |
|
|
wait_time = max(wait_time, 60 - (current_time - oldest_request)) |
|
|
|
|
|
|
|
|
current_tokens = sum(tokens for _, tokens in self.token_usage) |
|
|
if current_tokens + estimated_tokens > self.tpm: |
|
|
if self.token_usage: |
|
|
oldest_token_time = self.token_usage[0][0] |
|
|
wait_time = max(wait_time, 60 - (current_time - oldest_token_time)) |
|
|
|
|
|
return wait_time <= 0, wait_time |
|
|
|
|
|
def record_request(self, token_count: int): |
|
|
"""Record a successful request""" |
|
|
with self.lock: |
|
|
current_time = time.time() |
|
|
self.request_times.append(current_time) |
|
|
self.token_usage.append((current_time, token_count)) |
|
|
|
|
|
class GroqWrapper: |
|
|
"""Wrapper for direct Groq API with rate limiting and error handling""" |
|
|
|
|
|
def __init__(self, model: str = "qwen/qwen3-32b", |
|
|
rpm: int = 30, tpm: int = 6000): |
|
|
self.client = Groq(api_key=os.getenv("GROQ_API_KEY")) |
|
|
self.model = model |
|
|
self.rate_limiter = GroqRateLimiter(rpm=rpm, tpm=tpm) |
|
|
|
|
|
def estimate_tokens(self, messages: List[Dict]) -> int: |
|
|
"""Rough token estimation (4 chars ≈ 1 token)""" |
|
|
total_chars = sum(len(str(msg.get('content', ''))) for msg in messages) |
|
|
return max(total_chars // 4, 100) |
|
|
|
|
|
def invoke(self, messages: List[Dict], **kwargs) -> Dict: |
|
|
"""Invoke Groq API with rate limiting and retry logic""" |
|
|
|
|
|
groq_messages = [] |
|
|
for msg in messages: |
|
|
if hasattr(msg, 'content') and hasattr(msg, 'type'): |
|
|
|
|
|
role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else "system" |
|
|
groq_messages.append({"role": role, "content": str(msg.content)}) |
|
|
else: |
|
|
|
|
|
groq_messages.append(msg) |
|
|
|
|
|
estimated_tokens = self.estimate_tokens(groq_messages) |
|
|
|
|
|
max_retries = 3 |
|
|
for attempt in range(max_retries): |
|
|
try: |
|
|
|
|
|
can_proceed, wait_time = self.rate_limiter.can_make_request(estimated_tokens) |
|
|
if not can_proceed: |
|
|
logger.info(f"Rate limit: waiting {wait_time:.2f} seconds") |
|
|
time.sleep(wait_time) |
|
|
|
|
|
|
|
|
response = self.client.chat.completions.create( |
|
|
model=self.model, |
|
|
messages=groq_messages, |
|
|
**kwargs |
|
|
) |
|
|
|
|
|
|
|
|
actual_tokens = response.usage.total_tokens if hasattr(response, 'usage') else estimated_tokens |
|
|
self.rate_limiter.record_request(actual_tokens) |
|
|
|
|
|
|
|
|
content = response.choices[0].message.content |
|
|
return AIMessage(content=content) |
|
|
|
|
|
except RateLimitError as e: |
|
|
if attempt == max_retries - 1: |
|
|
raise e |
|
|
|
|
|
|
|
|
retry_after = getattr(e.response, 'headers', {}).get('retry-after') |
|
|
if retry_after: |
|
|
delay = float(retry_after) |
|
|
else: |
|
|
delay = 2 ** attempt |
|
|
|
|
|
logger.warning(f"Rate limited. Retrying in {delay} seconds (attempt {attempt + 1})") |
|
|
time.sleep(delay) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Groq API error: {e}") |
|
|
if attempt == max_retries - 1: |
|
|
raise e |
|
|
time.sleep(2 ** attempt) |
|
|
|
|
|
raise Exception("Max retries exceeded") |
|
|
|
|
|
def bind_tools(self, tools): |
|
|
"""Mock bind_tools method for compatibility""" |
|
|
self.tools = tools |
|
|
return self |
|
|
|
|
|
|
|
|
@tool |
|
|
def multiply(a: int, b: int) -> int: |
|
|
"""Multiply two numbers.""" |
|
|
return a * b |
|
|
|
|
|
@tool |
|
|
def add(a: int, b: int) -> int: |
|
|
"""Add two numbers.""" |
|
|
return a + b |
|
|
|
|
|
@tool |
|
|
def subtract(a: int, b: int) -> int: |
|
|
"""Subtract two numbers.""" |
|
|
return a - b |
|
|
|
|
|
@tool |
|
|
def divide(a: float, b: float) -> float: |
|
|
"""Divide two numbers.""" |
|
|
if b == 0: |
|
|
raise ValueError("Cannot divide by zero.") |
|
|
return a / b |
|
|
|
|
|
@tool |
|
|
def modulus(a: int, b: int) -> int: |
|
|
"""Get the modulus of two numbers.""" |
|
|
return a % b |
|
|
|
|
|
@tool |
|
|
def wiki_search(query: str) -> str: |
|
|
"""Search Wikipedia for a query and return maximum 2 results.""" |
|
|
try: |
|
|
search_docs = WikipediaLoader(query=query, load_max_docs=2).load() |
|
|
formatted_search_docs = "\n\n---\n\n".join( |
|
|
[ |
|
|
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>' |
|
|
for doc in search_docs |
|
|
]) |
|
|
return {"wiki_results": formatted_search_docs} |
|
|
except Exception as e: |
|
|
return {"wiki_results": f"Error: {str(e)}"} |
|
|
|
|
|
@tool |
|
|
def web_search(query: str) -> str: |
|
|
"""Search Tavily for a query and return maximum 3 results.""" |
|
|
try: |
|
|
search_docs = TavilySearchResults(max_results=3).invoke(query=query) |
|
|
formatted_search_docs = "\n\n---\n\n".join( |
|
|
[ |
|
|
f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")}\n</Document>' |
|
|
for doc in search_docs |
|
|
]) |
|
|
return {"web_results": formatted_search_docs} |
|
|
except Exception as e: |
|
|
return {"web_results": f"Error: {str(e)}"} |
|
|
|
|
|
@tool |
|
|
def arxiv_search(query: str) -> str: |
|
|
"""Search Arxiv for a query and return maximum 3 results.""" |
|
|
try: |
|
|
search_docs = ArxivLoader(query=query, load_max_docs=3).load() |
|
|
formatted_search_docs = "\n\n---\n\n".join( |
|
|
[ |
|
|
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>' |
|
|
for doc in search_docs |
|
|
]) |
|
|
return {"arxiv_results": formatted_search_docs} |
|
|
except Exception as e: |
|
|
return {"arxiv_results": f"Error: {str(e)}"} |
|
|
|
|
|
def load_system_prompt(): |
|
|
"""Load system prompt with error handling""" |
|
|
|
|
|
with open("system_prompt.txt", "r", encoding="utf-8") as f: |
|
|
return f.read() |
|
|
|
|
|
|
|
|
system_prompt = load_system_prompt() |
|
|
sys_msg = SystemMessage(content=system_prompt) |
|
|
|
|
|
tools = [ |
|
|
multiply, |
|
|
add, |
|
|
subtract, |
|
|
divide, |
|
|
modulus, |
|
|
wiki_search, |
|
|
web_search, |
|
|
arxiv_search, |
|
|
] |
|
|
|
|
|
def build_graph(provider: str = "direct_groq", model: str = "qwen/qwen3-32b"): |
|
|
"""Build the graph with direct Groq API and custom rate limiting""" |
|
|
|
|
|
if provider == "direct_groq": |
|
|
|
|
|
llm = GroqWrapper(model=model, rpm=30, tpm=6000) |
|
|
|
|
|
elif provider == "langchain_groq": |
|
|
|
|
|
from langchain_core.rate_limiters import InMemoryRateLimiter |
|
|
|
|
|
rate_limiter = InMemoryRateLimiter( |
|
|
requests_per_second=0.5, |
|
|
check_every_n_seconds=0.1, |
|
|
max_bucket_size=5, |
|
|
) |
|
|
|
|
|
from langchain_groq import ChatGroq |
|
|
llm = ChatGroq( |
|
|
model=model, |
|
|
temperature=0, |
|
|
groq_api_key=os.getenv("GROQ_API_KEY"), |
|
|
rate_limiter=rate_limiter |
|
|
) |
|
|
else: |
|
|
raise ValueError("Choose 'direct_groq' or 'langchain_groq'") |
|
|
|
|
|
|
|
|
llm_with_tools = llm.bind_tools(tools) |
|
|
|
|
|
def assistant(state: MessagesState): |
|
|
"""Assistant node""" |
|
|
try: |
|
|
response = llm_with_tools.invoke(state["messages"]) |
|
|
return {"messages": [response]} |
|
|
except Exception as e: |
|
|
logger.error(f"Assistant failed: {e}") |
|
|
error_msg = AIMessage(content=f"I encountered an error: {str(e)}") |
|
|
return {"messages": [error_msg]} |
|
|
|
|
|
|
|
|
builder = StateGraph(MessagesState) |
|
|
builder.add_node("assistant", assistant) |
|
|
builder.add_node("tools", ToolNode(tools)) |
|
|
builder.add_edge(START, "assistant") |
|
|
builder.add_conditional_edges("assistant", tools_condition) |
|
|
builder.add_edge("tools", "assistant") |
|
|
|
|
|
return builder.compile() |
|
|
|
|
|
if __name__ == "__main__": |
|
|
question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?" |
|
|
|
|
|
try: |
|
|
|
|
|
graph = build_graph(provider="direct_groq") |
|
|
messages = [HumanMessage(content=question)] |
|
|
result = graph.invoke({"messages": messages}) |
|
|
|
|
|
for m in result["messages"]: |
|
|
m.pretty_print() |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Test failed: {e}") |