Rajan Sharma committed on
Commit 327f806 · verified · 1 Parent(s): 15441ed

Update app.py

Files changed (1)
  1. app.py +97 -26
app.py CHANGED
@@ -1,11 +1,19 @@
-import os, requests, gradio as gr
+import os
+import time
+import requests
+import gradio as gr
 from safety import safety_filter, refusal_reply
 
-HF_API_URL = os.getenv("HF_API_URL", "https://api-inference.huggingface.co/models/tiiuae/Falcon3-7B-Instruct")
-HF_TOKEN = os.getenv("HF_TOKEN")  # store in Secrets (not Variables)
-MAX_NEW = int(os.getenv("MAX_NEW", "256"))
-TEMP = float(os.getenv("TEMP", "0.7"))
-TOP_P = float(os.getenv("TOP_P", "0.9"))
+# =========================
+# Config via env variables
+# =========================
+FALCON_URL = os.getenv("HF_API_URL_FALCON", "").strip()    # e.g., https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct
+PRIMARY_URL = os.getenv("HF_API_URL_PRIMARY", "").strip()  # e.g., https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.3
+HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
+
+MAX_NEW = int(os.getenv("MAX_NEW", "256"))
+TEMP = float(os.getenv("TEMP", "0.7"))
+TOP_P = float(os.getenv("TOP_P", "0.9"))
 
 HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
 
@@ -15,6 +23,12 @@ SYSTEM_MSG = (
     "decline with a brief rationale and offer safer alternatives."
 )
 
+# Ordered preference: try Falcon first, then primary (Mistral, Llama, etc.)
+API_ORDER = [u for u in [FALCON_URL, PRIMARY_URL] if u]
+
+# ==============
+# Prompt builder
+# ==============
 def build_prompt(history, user_input, max_turns=5):
     turns = history[-max_turns:] if history else []
     parts = [f"System: {SYSTEM_MSG}"]
@@ -25,7 +39,28 @@ def build_prompt(history, user_input, max_turns=5):
     parts.append("Assistant:")
     return "\n".join(parts)
 
-def call_hf_api(prompt):
+# ======================
+# Inference API routines
+# ======================
+def parse_api_response(data):
+    """
+    The HF Inference API can return:
+      - [{'generated_text': '...'}]
+      - {'generated_text': '...'}
+      - {'error': '...'}
+      - other shapes depending on backend
+    """
+    if isinstance(data, list) and data and isinstance(data[0], dict):
+        if "generated_text" in data[0]:
+            return data[0]["generated_text"]
+    if isinstance(data, dict):
+        if "generated_text" in data:
+            return data["generated_text"]
+        if "error" in data:
+            return f"⚠️ API error: {data['error']}"
+    return str(data)
+
+def call_api_once(url, prompt, timeout=120):
     payload = {
         "inputs": prompt,
         "parameters": {
@@ -35,48 +70,84 @@ def call_hf_api(prompt):
             "stop": ["\nUser:", "\nSystem:"]
         }
     }
-    r = requests.post(HF_API_URL, headers=HEADERS, json=payload, timeout=120)
+    r = requests.post(url, headers=HEADERS, json=payload, timeout=timeout)
+    if r.status_code == 503:
+        # model is loading; surface a friendly message for logs/UI
+        return None, f"🕙 Model cold start on {url}. Try again in a few seconds."
     r.raise_for_status()
-    data = r.json()
-    if isinstance(data, list) and data and "generated_text" in data[0]:
-        return data[0]["generated_text"]
-    if isinstance(data, dict) and "generated_text" in data:
-        return data["generated_text"]
-    if isinstance(data, dict) and "error" in data:
-        return f"⚠️ API error: {data['error']}"
-    return str(data)
+    return parse_api_response(r.json()), None
+
+def query_with_fallback(prompt):
+    """
+    Try endpoints in order. Returns (text, backend_label).
+    Raises last error if all fail.
+    """
+    last_err = None
+    for url in API_ORDER:
+        label = "Falcon" if url == FALCON_URL else "Primary"
+        try:
+            data, cold = call_api_once(url, prompt)
+            if cold:
+                # brief wait & retry same URL once
+                time.sleep(2)
+                data, cold = call_api_once(url, prompt)
+            if data:
+                return data, label
+        except Exception as e:
+            last_err = e
+    raise RuntimeError(f"All API endpoints failed. Last error: {last_err}")
 
-def vedagi_chat(user_input, history):
+# =====================
+# Chat + Safety wrapper
+# =====================
+def vedagi_chat(user_input, history, status):
+    # Pre-filter (RealSafe-style input)
     safe_in, blocked_in, reason_in = safety_filter(user_input, mode="input")
     if blocked_in:
-        return history + [[user_input, refusal_reply(reason_in)]]
+        return history + [[user_input, refusal_reply(reason_in)]], status
 
+    # Build prompt and query API (with fallback)
     prompt = build_prompt(history, safe_in)
    try:
-        out = call_hf_api(prompt)
+        out, backend = query_with_fallback(prompt)
    except Exception as e:
-        out = f"⚠️ API request failed: {e}"
 
-    # trim echoes
+    # Tidy echoes
     for tag in ("Assistant:", "System:", "User:"):
-        if out.startswith(tag):
+        if isinstance(out, str) and out.startswith(tag):
             out = out[len(tag):].strip()
 
+    # Post-filter (RealSafe-style output)
     safe_out, blocked_out, reason_out = safety_filter(out, mode="output")
     if blocked_out:
         safe_out = refusal_reply(reason_out)
 
-    return history + [[user_input, safe_out]]
+    # Update banner with active backend
+    status = f"**Backend:** {backend} • **MAX_NEW:** {MAX_NEW} • **TEMP:** {TEMP} • **TOP_P:** {TOP_P}"
+    return history + [[user_input, safe_out]], status
 
+# =====
+# UI
+# =====
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🌸 Ved AGI — v1 (HF Inference API)\nFalcon via serverless API + safety wrapper.")
+    gr.Markdown("# 🌸 Ved AGI — v1 (HF Inference API)")
+    status = gr.Markdown("**Backend:** (probing…) • **MAX_NEW:** "
+                         f"{MAX_NEW} • **TEMP:** {TEMP} • **TOP_P:** {TOP_P}")
     chat = gr.Chatbot(height=430)
     box = gr.Textbox(placeholder="Ask Ved AGI…", autofocus=True, label="Message")
     clear = gr.Button("Clear")
 
-    box.submit(vedagi_chat, [box, chat], [chat])
-    clear.click(lambda: None, None, [chat])
+    def _respond(msg, hist, stat):
+        return vedagi_chat(msg, hist, stat)
+
+    box.submit(_respond, [box, chat, status], [chat, status])
+    clear.click(lambda: ([], "Ready."), None, [chat, status])
 
 if __name__ == "__main__":
+    # Ensure we have at least one URL
+    if not API_ORDER:
+        raise RuntimeError("No API endpoints configured. Set HF_API_URL_FALCON and/or HF_API_URL_PRIMARY in Variables.")
     demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
+
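Reviewer note: a minimal, standalone smoke test of the endpoint setup this commit wires up. It mirrors the request that `call_api_once` sends; the prompt string and the parameter values below are illustrative assumptions (the hunk elides the lines that plug `MAX_NEW`/`TEMP`/`TOP_P` into the payload), and it presumes `HF_API_URL_FALCON` and `HF_TOKEN` are already set (token in Secrets, URLs in Variables).

```python
# smoke_test.py — hypothetical helper, not part of this commit.
# Mirrors the payload shape built in call_api_once; parameter values
# here are stand-ins for MAX_NEW / TEMP / TOP_P.
import os
import requests

url = os.environ["HF_API_URL_FALCON"]                            # assumed set in Variables
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}  # assumed set in Secrets

payload = {
    "inputs": "System: You are a helpful assistant.\nUser: Say hello.\nAssistant:",
    "parameters": {
        "max_new_tokens": 64,
        "temperature": 0.7,
        "top_p": 0.9,
        "stop": ["\nUser:", "\nSystem:"],
    },
}

r = requests.post(url, headers=headers, json=payload, timeout=120)
print(r.status_code)  # 503 is the cold-start case that query_with_fallback retries once
print(r.json())       # typically [{"generated_text": "..."}]
```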
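Also worth a quick offline check: `parse_api_response` normalizes the response shapes its docstring lists. The sketch below assumes `app.py` imports cleanly (gradio and the local `safety` module installed); the Blocks UI is constructed at import time but only launched under `__main__`.

```python
# shape_check.py — hypothetical, offline; no HTTP involved.
from app import parse_api_response

assert parse_api_response([{"generated_text": "hi"}]) == "hi"  # list-of-dicts shape
assert parse_api_response({"generated_text": "hi"}) == "hi"    # bare dict shape
assert parse_api_response({"error": "loading"}).startswith("⚠️ API error:")
assert parse_api_response(12345) == "12345"                    # unknown shapes fall back to str()
print("parse_api_response handles all documented shapes")
```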