"""LangGraph Agent with Direct Groq API and Custom Rate Limiting"""
import os
import time
import threading
from collections import deque
from typing import Dict, Any, List
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from groq import Groq, RateLimitError
import logging
# Load environment variables (e.g. GROQ_API_KEY, read later via os.getenv)
# from a .env file before any client is constructed.
load_dotenv()
# Set up logging: INFO level on the root logger, plus a module-level logger
# used by the classes and functions in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GroqRateLimiter:
    """Sliding-window (60 s) rate limiter for direct Groq API calls.

    Tracks request timestamps and per-request token counts and answers
    "may I send a request now, and if not, how long should I wait?".

    Each method holds ``self.lock`` while it touches the internal deques.
    ``can_make_request`` only inspects state; it does not reserve capacity,
    so callers must call ``record_request`` after a successful API call.
    """

    def __init__(self, rpm: int = 20, tpm: int = 6000):
        """Create a limiter.

        Args:
            rpm: Maximum requests allowed per rolling minute.
            tpm: Maximum tokens allowed per rolling minute.
        """
        self.rpm = rpm  # Requests per minute
        self.tpm = tpm  # Tokens per minute
        self.request_times = deque()  # timestamps, oldest first
        self.token_usage = deque()  # (timestamp, token_count) tuples, oldest first
        self.lock = threading.Lock()

    def _clean_old_records(self, current_time: float):
        """Remove records older than 1 minute (caller must hold the lock)."""
        minute_ago = current_time - 60
        while self.request_times and self.request_times[0] <= minute_ago:
            self.request_times.popleft()
        while self.token_usage and self.token_usage[0][0] <= minute_ago:
            self.token_usage.popleft()

    def can_make_request(self, estimated_tokens: int = 1000) -> tuple[bool, float]:
        """Check if a request can be made.

        Args:
            estimated_tokens: Expected token cost of the request.

        Returns:
            ``(can_proceed, wait_time)``. ``wait_time`` is the number of
            seconds until enough recorded usage has left the window for the
            request to fit; it is 0 when ``can_proceed`` is True.
        """
        with self.lock:
            current_time = time.time()
            self._clean_old_records(current_time)
            wait_time = 0.0

            # Check RPM limit. With N recorded requests and a limit of R we
            # need N - R + 1 of them to expire, i.e. wait for the record at
            # index N - R (index 0 in the usual N == R case).
            excess = len(self.request_times) - self.rpm
            if excess >= 0 and self.request_times:
                blocking_index = min(excess, len(self.request_times) - 1)
                blocking_time = self.request_times[blocking_index]
                wait_time = max(wait_time, 60 - (current_time - blocking_time))

            # Check TPM limit. Waiting only for the oldest record is not
            # enough when it frees fewer tokens than we are over budget, so
            # walk forward until the freed tokens cover the overflow.
            current_tokens = sum(tokens for _, tokens in self.token_usage)
            overflow = current_tokens + estimated_tokens - self.tpm
            if overflow > 0 and self.token_usage:
                # If the request alone exceeds tpm, the best we can do is
                # wait for the window to drain completely (newest record).
                release_time = self.token_usage[-1][0]
                freed = 0
                for timestamp, tokens in self.token_usage:
                    freed += tokens
                    if freed >= overflow:
                        release_time = timestamp
                        break
                wait_time = max(wait_time, 60 - (current_time - release_time))

            return wait_time <= 0, wait_time

    def record_request(self, token_count: int):
        """Record a successful request and the tokens it consumed."""
        with self.lock:
            current_time = time.time()
            self.request_times.append(current_time)
            self.token_usage.append((current_time, token_count))
class GroqWrapper:
    """Wrapper for the direct Groq client with rate limiting and retries.

    Exposes ``invoke`` and ``bind_tools`` so it can stand in for a LangChain
    chat model inside the graph. Tools passed to ``bind_tools`` are only
    stored; they are NOT forwarded to the Groq API, so responses never carry
    tool calls.
    """

    def __init__(self, model: str = "qwen/qwen3-32b",
                 rpm: int = 30, tpm: int = 6000):
        """Create the client.

        Args:
            model: Groq model identifier.
            rpm: Requests-per-minute budget for the local rate limiter.
            tpm: Tokens-per-minute budget for the local rate limiter.
        """
        self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.model = model
        self.rate_limiter = GroqRateLimiter(rpm=rpm, tpm=tpm)
        # Defined up front so the attribute exists before bind_tools() is called.
        self.tools = []

    def estimate_tokens(self, messages: List[Dict]) -> int:
        """Rough token estimation (4 chars ≈ 1 token), never below 100."""
        # `or ''` keeps a missing/None content from being counted as "None".
        total_chars = sum(len(str(msg.get('content') or '')) for msg in messages)
        return max(total_chars // 4, 100)

    @staticmethod
    def _to_groq_message(msg):
        """Convert one LangChain message object to a Groq role/content dict."""
        if hasattr(msg, 'content') and hasattr(msg, 'type'):
            # LangChain message object: "human" -> user, "ai" -> assistant,
            # every other type is sent with the "system" role.
            if msg.type == "human":
                role = "user"
            elif msg.type == "ai":
                role = "assistant"
            else:
                role = "system"
            return {"role": role, "content": str(msg.content)}
        # Already in dict format
        return msg

    def _wait_for_capacity(self, estimated_tokens: int) -> None:
        """Block until the local rate limiter allows the request."""
        while True:
            can_proceed, wait_time = self.rate_limiter.can_make_request(estimated_tokens)
            if can_proceed:
                return
            logger.info(f"Rate limit: waiting {wait_time:.2f} seconds")
            time.sleep(wait_time)
            # Loop to re-check: one sleep may not have freed enough budget.

    @staticmethod
    def _retry_delay(error, attempt: int):
        """Return seconds to wait after a rate-limit error."""
        headers = getattr(getattr(error, "response", None), "headers", None) or {}
        retry_after = headers.get('retry-after')
        if retry_after:
            try:
                return max(0.0, float(retry_after))
            except (TypeError, ValueError):
                pass  # non-numeric header value; fall back to backoff
        return 2 ** attempt  # Exponential backoff

    def invoke(self, messages: List[Any], **kwargs) -> "AIMessage":
        """Invoke the Groq API with rate limiting and retry logic.

        Args:
            messages: LangChain message objects and/or Groq-style dicts.
            **kwargs: Passed through to ``chat.completions.create``.

        Returns:
            An ``AIMessage`` holding the first choice's text content.

        Raises:
            RateLimitError: If the API is still rate limiting on the last attempt.
            Exception: Any other API error raised on the last attempt.
        """
        groq_messages = [self._to_groq_message(msg) for msg in messages]
        estimated_tokens = self.estimate_tokens(groq_messages)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Check rate limits
                self._wait_for_capacity(estimated_tokens)
                # Make the API call
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=groq_messages,
                    **kwargs
                )
                # Record successful request; `usage` may be absent or None.
                usage = getattr(response, 'usage', None)
                actual_tokens = usage.total_tokens if usage is not None else estimated_tokens
                self.rate_limiter.record_request(actual_tokens)
                # Convert back to LangChain format
                content = response.choices[0].message.content
                return AIMessage(content=content or "")
            except RateLimitError as e:
                if attempt == max_retries - 1:
                    raise
                # Use retry-after header if available
                delay = self._retry_delay(e, attempt)
                logger.warning(f"Rate limited. Retrying in {delay} seconds (attempt {attempt + 1})")
                time.sleep(delay)
            except Exception as e:
                logger.error(f"Groq API error: {e}")
                if attempt == max_retries - 1:
                    raise
                time.sleep(2 ** attempt)
        raise Exception("Max retries exceeded")

    def bind_tools(self, tools):
        """Mock bind_tools method for compatibility.

        Stores ``tools`` on the instance and returns ``self``; the tools are
        not sent to the Groq API.
        """
        self.tools = tools
        return self
# Tool definitions exposed to the agent
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    # Product of the two operands.
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    # Sum of the two operands.
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    # `b` is taken away from `a`.
    difference = a - b
    return difference
@tool
def divide(a: float, b: float) -> float:
    """Divide two numbers."""
    # Happy path first; a zero divisor is reported as a ValueError.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers."""
    # Mirror divide(): report a zero divisor as a ValueError with a clear
    # message instead of letting a raw ZeroDivisionError escape.
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b
@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results."""
    # Returns {"wiki_results": <text>}: page contents joined by "---"
    # separators, or an "Error: ..." string if loading fails. The return
    # annotation now matches the dict that is actually returned.
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        formatted_search_docs = "\n\n---\n\n".join(
            f'\n{doc.page_content}\n' for doc in search_docs
        )
        return {"wiki_results": formatted_search_docs}
    except Exception as e:
        # Best effort: hand the failure back as text rather than raising.
        return {"wiki_results": f"Error: {str(e)}"}
@tool
def web_search(query: str) -> Dict[str, str]:
    """Search Tavily for a query and return maximum 3 results."""
    # Returns {"web_results": <text>}: result contents joined by "---"
    # separators, or an "Error: ..." string if the search fails.
    try:
        # The tool input must be passed positionally: invoke() has no `query`
        # keyword, so `invoke(query=query)` raised TypeError on every call.
        search_docs = TavilySearchResults(max_results=3).invoke(query)
        if isinstance(search_docs, str):
            # NOTE(review): some versions appear to return an error string
            # instead of a result list; surface it as-is. Confirm against
            # the installed langchain_community version.
            return {"web_results": f"Error: {search_docs}"}
        formatted_search_docs = "\n\n---\n\n".join(
            f'\n{doc.get("content", "")}\n' for doc in search_docs
        )
        return {"web_results": formatted_search_docs}
    except Exception as e:
        # Best effort: hand the failure back as text rather than raising.
        return {"web_results": f"Error: {str(e)}"}
@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search Arxiv for a query and return maximum 3 results."""
    # Returns {"arxiv_results": <text>}: the first 1000 characters of each
    # document joined by "---" separators, or an "Error: ..." string if
    # loading fails. The return annotation now matches the dict that is
    # actually returned.
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        formatted_search_docs = "\n\n---\n\n".join(
            f'\n{doc.page_content[:1000]}\n' for doc in search_docs
        )
        return {"arxiv_results": formatted_search_docs}
    except Exception as e:
        # Best effort: hand the failure back as text rather than raising.
        return {"arxiv_results": f"Error: {str(e)}"}
def load_system_prompt(path: str = "system_prompt.txt") -> str:
    """Load system prompt with error handling.

    Args:
        path: Location of the UTF-8 encoded prompt file. A relative path is
            resolved against the current working directory.

    Returns:
        The file contents, or a generic fallback prompt if the file cannot
        be opened or read (a warning is logged in that case).
    """
    fallback_prompt = "You are a helpful assistant."
    try:
        with open(path, "r", encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        # Missing or unreadable prompt file should not make the module
        # unimportable; degrade to a generic prompt and make it visible.
        logger.warning(f"Could not read system prompt from {path!r}: {e}. Using fallback prompt.")
        return fallback_prompt
# Read the prompt once at import time and wrap it in a LangChain
# SystemMessage so it is available as a ready-made message object.
system_prompt = load_system_prompt()
sys_msg = SystemMessage(content=system_prompt)
# Tool registry handed to both the LLM (bind_tools) and the graph's ToolNode.
tools = [
    # Arithmetic helpers
    multiply, add, subtract, divide, modulus,
    # Search / retrieval helpers
    wiki_search, web_search, arxiv_search,
]
def build_graph(provider: str = "direct_groq", model: str = "qwen/qwen3-32b"):
    """Build the agent graph with direct Groq API or LangChain's ChatGroq.

    Args:
        provider: "direct_groq" uses the custom GroqWrapper with its own rate
            limiter; "langchain_groq" uses ChatGroq with LangChain's
            InMemoryRateLimiter.
        model: Groq model identifier passed to the selected client.

    Returns:
        The compiled graph: START -> assistant, assistant -> tools when
        ``tools_condition`` routes there, and tools -> assistant.

    Raises:
        ValueError: If ``provider`` is not one of the two supported values.
    """
    if provider == "direct_groq":
        # Use custom Groq wrapper with rate limiting
        llm = GroqWrapper(model=model, rpm=30, tpm=6000)  # Adjust based on your plan
    elif provider == "langchain_groq":
        # Imported lazily so these are only required when this provider is chosen.
        from langchain_core.rate_limiters import InMemoryRateLimiter
        from langchain_groq import ChatGroq
        # Use LangChain's ChatGroq with native rate limiting
        rate_limiter = InMemoryRateLimiter(
            requests_per_second=0.5,  # 30 RPM
            check_every_n_seconds=0.1,
            max_bucket_size=5,
        )
        llm = ChatGroq(
            model=model,
            temperature=0,
            groq_api_key=os.getenv("GROQ_API_KEY"),
            rate_limiter=rate_limiter
        )
    else:
        raise ValueError("Choose 'direct_groq' or 'langchain_groq'")

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: call the LLM on the conversation so far."""
        history = list(state["messages"])
        # The module-level system prompt was loaded but never sent to the
        # model. Prepend it unless the caller already supplied one, so it is
        # not duplicated.
        if not history or not isinstance(history[0], SystemMessage):
            history = [sys_msg] + history
        try:
            response = llm_with_tools.invoke(history)
            return {"messages": [response]}
        except Exception as e:
            # Surface the failure as an assistant message instead of
            # aborting the whole graph run.
            logger.error(f"Assistant failed: {e}")
            error_msg = AIMessage(content=f"I encountered an error: {str(e)}")
            return {"messages": [error_msg]}

    # Build the graph
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Manual smoke test: run one question through the direct-Groq graph and
    # print every message of the resulting conversation.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    try:
        # Test with direct Groq API
        graph = build_graph(provider="direct_groq")
        result = graph.invoke({"messages": [HumanMessage(content=question)]})
        for message in result["messages"]:
            message.pretty_print()
    except Exception as e:
        logger.error(f"Test failed: {e}")