bwilkie committed
Commit c99f4b7 · verified · 1 Parent(s): e268f7b

Update agent_simple.py

Files changed (1):
  agent_simple.py +235 -115

agent_simple.py CHANGED
@@ -1,133 +1,238 @@
- """LangGraph Agent"""
  import os
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import tools_condition
  from langgraph.prebuilt import ToolNode
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader
  from langchain_community.document_loaders import ArxivLoader
- from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool

  load_dotenv()

- from langchain_core.rate_limiters import InMemoryRateLimiter
-
- # Create a rate limiter
- rate_limiter = InMemoryRateLimiter(
-     requests_per_second=0.1,  # Once every 10 seconds
-     check_every_n_seconds=0.1,
-     max_bucket_size=10,
- )

  @tool
  def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-     Args:
-         a: first int
-         b: second int
-     """
      return a * b

  @tool
  def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a + b

  @tool
  def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a - b

  @tool
- def divide(a: int, b: int) -> int:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
  def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a % b

  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"wiki_results": formatted_search_docs}

  @tool
  def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"web_results": formatted_search_docs}

  @tool
- def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.
-
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"arvix_results": formatted_search_docs}
-

- # load the system prompt from the file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # System message
  sys_msg = SystemMessage(content=system_prompt)

-
  tools = [
      multiply,
      add,
@@ -136,55 +241,70 @@ tools = [
      modulus,
      wiki_search,
      web_search,
-     arvix_search,
  ]

- # Build graph function
- def build_graph(provider: str = "groq"):
-     """Build the graph"""
-     # Load environment variables from .env file
-     if provider == "google":
-         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0, rate_limiter=rate_limiter)
-     elif provider == "groq":
-         llm = ChatGroq(model="qwen-qwen3-32b", temperature=0, rate_limiter=rate_limiter)
-     elif provider == "huggingface":
-         llm = ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-                 rate_limiter=rate_limiter,
-             ),
          )
      else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
      # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)

-     # Node
      def assistant(state: MessagesState):
          """Assistant node"""
-         return {"messages": [llm_with_tools.invoke(state["messages"])]}

      builder = StateGraph(MessagesState)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
      builder.add_edge(START, "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
      builder.add_edge("tools", "assistant")

-     # Compile graph
      return builder.compile()

- # test
  if __name__ == "__main__":
      question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
-     graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
-     for m in messages["messages"]:
-         m.pretty_print()

+ """LangGraph Agent with Direct Groq API and Custom Rate Limiting"""
  import os
+ import time
+ import threading
+ from collections import deque
+ from typing import Dict, Any, List
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import tools_condition
  from langgraph.prebuilt import ToolNode
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader
  from langchain_community.document_loaders import ArxivLoader
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_core.tools import tool
+ from groq import Groq, RateLimitError
+ import logging

  load_dotenv()

+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ class GroqRateLimiter:
+     """Thread-safe rate limiter for direct Groq API calls"""
+
+     def __init__(self, rpm: int = 20, tpm: int = 6000):
+         self.rpm = rpm  # Requests per minute
+         self.tpm = tpm  # Tokens per minute
+         self.request_times = deque()
+         self.token_usage = deque()  # (timestamp, token_count) tuples
+         self.lock = threading.Lock()
+
+     def _clean_old_records(self, current_time: float):
+         """Remove records older than 1 minute"""
+         minute_ago = current_time - 60
+
+         while self.request_times and self.request_times[0] <= minute_ago:
+             self.request_times.popleft()
+
+         while self.token_usage and self.token_usage[0][0] <= minute_ago:
+             self.token_usage.popleft()
+
+     def can_make_request(self, estimated_tokens: int = 1000) -> tuple[bool, float]:
+         """Check if request can be made, return (can_proceed, wait_time)"""
+         with self.lock:
+             current_time = time.time()
+             self._clean_old_records(current_time)
+
+             wait_time = 0
+
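+             # Sliding 60-second window: a request may proceed only once the
+             # oldest recorded request (or token batch) falls out of the window.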
+             # Check RPM limit
+             if len(self.request_times) >= self.rpm:
+                 oldest_request = self.request_times[0]
+                 wait_time = max(wait_time, 60 - (current_time - oldest_request))
+
+             # Check TPM limit
+             current_tokens = sum(tokens for _, tokens in self.token_usage)
+             if current_tokens + estimated_tokens > self.tpm:
+                 if self.token_usage:
+                     oldest_token_time = self.token_usage[0][0]
+                     wait_time = max(wait_time, 60 - (current_time - oldest_token_time))
+
+             return wait_time <= 0, wait_time
+
+     def record_request(self, token_count: int):
+         """Record a successful request"""
+         with self.lock:
+             current_time = time.time()
+             self.request_times.append(current_time)
+             self.token_usage.append((current_time, token_count))
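
+ # GroqWrapper pairs the limiter with the raw Groq client: it estimates tokens
+ # before each call, sleeps when a budget is exhausted, and retries on
+ # RateLimitError using the Retry-After header or exponential backoff.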
+ class GroqWrapper:
+     """Wrapper for direct Groq API with rate limiting and error handling"""
+
+     def __init__(self, model: str = "llama-3.1-70b-versatile",
+                  rpm: int = 30, tpm: int = 6000):
+         self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
+         self.model = model
+         self.rate_limiter = GroqRateLimiter(rpm=rpm, tpm=tpm)
+
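+     # NOTE: a crude length-based estimate; the limiter is corrected with the
+     # actual token usage reported by the API after each successful call.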
+     def estimate_tokens(self, messages: List[Dict]) -> int:
+         """Rough token estimation (4 chars ≈ 1 token)"""
+         total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
+         return max(total_chars // 4, 100)
+
+     def invoke(self, messages: List[Dict], **kwargs) -> AIMessage:
+         """Invoke Groq API with rate limiting and retry logic"""
+         # Convert LangChain messages to Groq format if needed
+         groq_messages = []
+         for msg in messages:
+             if hasattr(msg, 'content') and hasattr(msg, 'type'):
+                 # LangChain message object
+                 role = "user" if msg.type == "human" else "assistant" if msg.type == "ai" else "system"
+                 groq_messages.append({"role": role, "content": str(msg.content)})
+             else:
+                 # Already in dict format
+                 groq_messages.append(msg)
+
+         estimated_tokens = self.estimate_tokens(groq_messages)
+
+         max_retries = 3
+         for attempt in range(max_retries):
+             try:
+                 # Check rate limits
+                 can_proceed, wait_time = self.rate_limiter.can_make_request(estimated_tokens)
+                 if not can_proceed:
+                     logger.info(f"Rate limit: waiting {wait_time:.2f} seconds")
+                     time.sleep(wait_time)
+
+                 # Make the API call
+                 response = self.client.chat.completions.create(
+                     model=self.model,
+                     messages=groq_messages,
+                     **kwargs
+                 )
+
+                 # Record successful request
+                 actual_tokens = response.usage.total_tokens if hasattr(response, 'usage') else estimated_tokens
+                 self.rate_limiter.record_request(actual_tokens)
+
+                 # Convert back to LangChain format
+                 content = response.choices[0].message.content
+                 return AIMessage(content=content)
+
+             except RateLimitError as e:
+                 if attempt == max_retries - 1:
+                     raise e
+
+                 # Use retry-after header if available
+                 retry_after = getattr(e.response, 'headers', {}).get('retry-after')
+                 if retry_after:
+                     delay = float(retry_after)
+                 else:
+                     delay = 2 ** attempt  # Exponential backoff
+
+                 logger.warning(f"Rate limited. Retrying in {delay} seconds (attempt {attempt + 1})")
+                 time.sleep(delay)
+
+             except Exception as e:
+                 logger.error(f"Groq API error: {e}")
+                 if attempt == max_retries - 1:
+                     raise e
+                 time.sleep(2 ** attempt)
+
+         raise Exception("Max retries exceeded")
+
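+     # NOTE: unlike ChatGroq.bind_tools, this merely stores the tools for
+     # interface compatibility; they are not forwarded to the API, so the model
+     # cannot emit tool calls through this wrapper.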
+     def bind_tools(self, tools):
+         """Mock bind_tools method for compatibility"""
+         self.tools = tools
+         return self

+ # Your existing tools
  @tool
  def multiply(a: int, b: int) -> int:
+     """Multiply two numbers."""
      return a * b

  @tool
  def add(a: int, b: int) -> int:
+     """Add two numbers."""
      return a + b

  @tool
  def subtract(a: int, b: int) -> int:
+     """Subtract two numbers."""
      return a - b

  @tool
+ def divide(a: float, b: float) -> float:
+     """Divide two numbers."""
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
  def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers."""
      return a % b

  @tool
  def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return maximum 2 results."""
+     try:
+         search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+         formatted_search_docs = "\n\n---\n\n".join(
+             [
+                 f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+                 for doc in search_docs
+             ])
+         return {"wiki_results": formatted_search_docs}
+     except Exception as e:
+         return {"wiki_results": f"Error: {str(e)}"}

  @tool
  def web_search(query: str) -> str:
+     """Search Tavily for a query and return maximum 3 results."""
+     try:
+         search_docs = TavilySearchResults(max_results=3).invoke(query=query)
+         formatted_search_docs = "\n\n---\n\n".join(
+             [
+                 f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")}\n</Document>'
+                 for doc in search_docs
+             ])
+         return {"web_results": formatted_search_docs}
+     except Exception as e:
+         return {"web_results": f"Error: {str(e)}"}

  @tool
+ def arxiv_search(query: str) -> str:
+     """Search Arxiv for a query and return maximum 3 results."""
+     try:
+         search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+         formatted_search_docs = "\n\n---\n\n".join(
+             [
+                 f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+                 for doc in search_docs
+             ])
+         return {"arxiv_results": formatted_search_docs}
+     except Exception as e:
+         return {"arxiv_results": f"Error: {str(e)}"}

+ def load_system_prompt():
+     """Load system prompt with error handling"""
+     try:
+         with open("system_prompt.txt", "r", encoding="utf-8") as f:
+             return f.read()
+     except FileNotFoundError:
+         logger.warning("system_prompt.txt not found, using default prompt")
+         return "You are a helpful AI assistant."

+ system_prompt = load_system_prompt()
  sys_msg = SystemMessage(content=system_prompt)

  tools = [
      multiply,
      add,
      subtract,
      divide,
      modulus,
      wiki_search,
      web_search,
+     arxiv_search,
  ]

+ def build_graph(provider: str = "direct_groq", model: str = "llama-3.1-70b-versatile"):
+     """Build the graph with direct Groq API and custom rate limiting"""
+
+     if provider == "direct_groq":
+         # Use custom Groq wrapper with rate limiting
+         llm = GroqWrapper(model=model, rpm=30, tpm=6000)  # Adjust based on your plan
+
+     elif provider == "langchain_groq":
+         # Use LangChain's ChatGroq with native rate limiting
+         from langchain_core.rate_limiters import InMemoryRateLimiter
+
+         rate_limiter = InMemoryRateLimiter(
+             requests_per_second=0.5,  # 30 RPM
+             check_every_n_seconds=0.1,
+             max_bucket_size=5,
+         )
+
+         from langchain_groq import ChatGroq
+         llm = ChatGroq(
+             model=model,
+             temperature=0,
+             groq_api_key=os.getenv("GROQ_API_KEY"),
+             rate_limiter=rate_limiter
          )
      else:
+         raise ValueError("Choose 'direct_groq' or 'langchain_groq'")
+
      # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)
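
+     # NOTE: the assistant node below traps API errors and returns them as an
+     # AIMessage, so a failed call ends the run gracefully instead of raising
+     # inside the graph.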
      def assistant(state: MessagesState):
          """Assistant node"""
+         try:
+             response = llm_with_tools.invoke(state["messages"])
+             return {"messages": [response]}
+         except Exception as e:
+             logger.error(f"Assistant failed: {e}")
+             error_msg = AIMessage(content=f"I encountered an error: {str(e)}")
+             return {"messages": [error_msg]}

+     # Build the graph
      builder = StateGraph(MessagesState)
      builder.add_node("assistant", assistant)
      builder.add_node("tools", ToolNode(tools))
      builder.add_edge(START, "assistant")
+     builder.add_conditional_edges("assistant", tools_condition)
      builder.add_edge("tools", "assistant")

      return builder.compile()

  if __name__ == "__main__":
      question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+
+     try:
+         # Test with direct Groq API
+         graph = build_graph(provider="direct_groq")
+         messages = [HumanMessage(content=question)]
+         result = graph.invoke({"messages": messages})
+
+         for m in result["messages"]:
+             m.pretty_print()
+
+     except Exception as e:
+         logger.error(f"Test failed: {e}")