"""LangGraph Agent with Direct Groq API and Custom Rate Limiting""" import os import time import threading from collections import deque from typing import Dict, Any, List from dotenv import load_dotenv from langgraph.graph import START, StateGraph, MessagesState from langgraph.prebuilt import tools_condition from langgraph.prebuilt import ToolNode from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.document_loaders import WikipediaLoader from langchain_community.document_loaders import ArxivLoader from langchain_core.messages import SystemMessage, HumanMessage, AIMessage from langchain_core.tools import tool from groq import Groq, RateLimitError import logging load_dotenv() # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class GroqRateLimiter: """Thread-safe rate limiter for direct Groq API calls""" def __init__(self, rpm: int = 20, tpm: int = 6000): self.rpm = rpm # Requests per minute self.tpm = tpm # Tokens per minute self.request_times = deque() self.token_usage = deque() # (timestamp, token_count) tuples self.lock = threading.Lock() def _clean_old_records(self, current_time: float): """Remove records older than 1 minute""" minute_ago = current_time - 60 while self.request_times and self.request_times[0] <= minute_ago: self.request_times.popleft() while self.token_usage and self.token_usage[0][0] <= minute_ago: self.token_usage.popleft() def can_make_request(self, estimated_tokens: int = 1000) -> tuple[bool, float]: """Check if request can be made, return (can_proceed, wait_time)""" with self.lock: current_time = time.time() self._clean_old_records(current_time) wait_time = 0 # Check RPM limit if len(self.request_times) >= self.rpm: oldest_request = self.request_times[0] wait_time = max(wait_time, 60 - (current_time - oldest_request)) # Check TPM limit current_tokens = sum(tokens for _, tokens in self.token_usage) if current_tokens + estimated_tokens > self.tpm: if self.token_usage: oldest_token_time = self.token_usage[0][0] wait_time = max(wait_time, 60 - (current_time - oldest_token_time)) return wait_time <= 0, wait_time def record_request(self, token_count: int): """Record a successful request""" with self.lock: current_time = time.time() self.request_times.append(current_time) self.token_usage.append((current_time, token_count)) class GroqWrapper: """Wrapper for direct Groq API with rate limiting and error handling""" def __init__(self, model: str = "qwen/qwen3-32b", rpm: int = 30, tpm: int = 6000): self.client = Groq(api_key=os.getenv("GROQ_API_KEY")) self.model = model self.rate_limiter = GroqRateLimiter(rpm=rpm, tpm=tpm) def estimate_tokens(self, messages: List[Dict]) -> int: """Rough token estimation (4 chars ≈ 1 token)""" total_chars = sum(len(str(msg.get('content', ''))) for msg in messages) return max(total_chars // 4, 100) def invoke(self, messages: List[Dict], **kwargs) -> Dict: """Invoke Groq API with rate limiting and retry logic""" # Convert LangChain messages to Groq format if needed groq_messages = [] for msg in messages: if hasattr(msg, 'content') and hasattr(msg, 'type'): # LangChain message object role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else "system" groq_messages.append({"role": role, "content": str(msg.content)}) else: # Already in dict format groq_messages.append(msg) estimated_tokens = self.estimate_tokens(groq_messages) max_retries = 3 for attempt in range(max_retries): try: # Check rate limits can_proceed, wait_time = self.rate_limiter.can_make_request(estimated_tokens) if not can_proceed: logger.info(f"Rate limit: waiting {wait_time:.2f} seconds") time.sleep(wait_time) # Make the API call response = self.client.chat.completions.create( model=self.model, messages=groq_messages, **kwargs ) # Record successful request actual_tokens = response.usage.total_tokens if hasattr(response, 'usage') else estimated_tokens self.rate_limiter.record_request(actual_tokens) # Convert back to LangChain format content = response.choices[0].message.content return AIMessage(content=content) except RateLimitError as e: if attempt == max_retries - 1: raise e # Use retry-after header if available retry_after = getattr(e.response, 'headers', {}).get('retry-after') if retry_after: delay = float(retry_after) else: delay = 2 ** attempt # Exponential backoff logger.warning(f"Rate limited. Retrying in {delay} seconds (attempt {attempt + 1})") time.sleep(delay) except Exception as e: logger.error(f"Groq API error: {e}") if attempt == max_retries - 1: raise e time.sleep(2 ** attempt) raise Exception("Max retries exceeded") def bind_tools(self, tools): """Mock bind_tools method for compatibility""" self.tools = tools return self # Your existing tools @tool def multiply(a: int, b: int) -> int: """Multiply two numbers.""" return a * b @tool def add(a: int, b: int) -> int: """Add two numbers.""" return a + b @tool def subtract(a: int, b: int) -> int: """Subtract two numbers.""" return a - b @tool def divide(a: float, b: float) -> float: """Divide two numbers.""" if b == 0: raise ValueError("Cannot divide by zero.") return a / b @tool def modulus(a: int, b: int) -> int: """Get the modulus of two numbers.""" return a % b @tool def wiki_search(query: str) -> str: """Search Wikipedia for a query and return maximum 2 results.""" try: search_docs = WikipediaLoader(query=query, load_max_docs=2).load() formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.page_content}\n' for doc in search_docs ]) return {"wiki_results": formatted_search_docs} except Exception as e: return {"wiki_results": f"Error: {str(e)}"} @tool def web_search(query: str) -> str: """Search Tavily for a query and return maximum 3 results.""" try: search_docs = TavilySearchResults(max_results=3).invoke(query=query) formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.get("content", "")}\n' for doc in search_docs ]) return {"web_results": formatted_search_docs} except Exception as e: return {"web_results": f"Error: {str(e)}"} @tool def arxiv_search(query: str) -> str: """Search Arxiv for a query and return maximum 3 results.""" try: search_docs = ArxivLoader(query=query, load_max_docs=3).load() formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.page_content[:1000]}\n' for doc in search_docs ]) return {"arxiv_results": formatted_search_docs} except Exception as e: return {"arxiv_results": f"Error: {str(e)}"} def load_system_prompt(): """Load system prompt with error handling""" with open("system_prompt.txt", "r", encoding="utf-8") as f: return f.read() system_prompt = load_system_prompt() sys_msg = SystemMessage(content=system_prompt) tools = [ multiply, add, subtract, divide, modulus, wiki_search, web_search, arxiv_search, ] def build_graph(provider: str = "direct_groq", model: str = "qwen/qwen3-32b"): """Build the graph with direct Groq API and custom rate limiting""" if provider == "direct_groq": # Use custom Groq wrapper with rate limiting llm = GroqWrapper(model=model, rpm=30, tpm=6000) # Adjust based on your plan elif provider == "langchain_groq": # Use LangChain's ChatGroq with native rate limiting from langchain_core.rate_limiters import InMemoryRateLimiter rate_limiter = InMemoryRateLimiter( requests_per_second=0.5, # 30 RPM check_every_n_seconds=0.1, max_bucket_size=5, ) from langchain_groq import ChatGroq llm = ChatGroq( model=model, temperature=0, groq_api_key=os.getenv("GROQ_API_KEY"), rate_limiter=rate_limiter ) else: raise ValueError("Choose 'direct_groq' or 'langchain_groq'") # Bind tools to LLM llm_with_tools = llm.bind_tools(tools) def assistant(state: MessagesState): """Assistant node""" try: response = llm_with_tools.invoke(state["messages"]) return {"messages": [response]} except Exception as e: logger.error(f"Assistant failed: {e}") error_msg = AIMessage(content=f"I encountered an error: {str(e)}") return {"messages": [error_msg]} # Build the graph builder = StateGraph(MessagesState) builder.add_node("assistant", assistant) builder.add_node("tools", ToolNode(tools)) builder.add_edge(START, "assistant") builder.add_conditional_edges("assistant", tools_condition) builder.add_edge("tools", "assistant") return builder.compile() if __name__ == "__main__": question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?" try: # Test with direct Groq API graph = build_graph(provider="direct_groq") messages = [HumanMessage(content=question)] result = graph.invoke({"messages": messages}) for m in result["messages"]: m.pretty_print() except Exception as e: logger.error(f"Test failed: {e}")