# Spaces: Paused  (hosting-page status banner captured with the source; not part of the program)
import asyncio
import html
import os
import re
import shlex

import requests
import streamlit as st
from bs4 import BeautifulSoup
from llama_cpp import Llama
# Set page configuration
# NOTE(review): the icon glyph arrived mis-encoded (it read "π"); a padlock is
# presumed to be the intended emoji -- confirm with the author.
st.set_page_config(page_title="Security Assistant", page_icon="🔒", layout="wide")

# Custom CSS for styling: one bubble style per speaker plus a bordered box for
# tool output. The CSS is kept at column 0 inside the string so Markdown does
# not treat it as an indented code block.
st.markdown(
    """
<style>
.user-message { background-color: #DCF8C6; padding: 10px; border-radius: 10px; margin: 5px 0; }
.assistant-message { background-color: #E9ECEF; padding: 10px; border-radius: 10px; margin: 5px 0; }
.tool-output { background-color: #F8F9FA; padding: 10px; border-radius: 10px; border: 1px solid #DEE2E6; }
</style>
""",
    unsafe_allow_html=True,
)
# Cache the model loading
@st.cache_resource
def _load_cached_model(model_path: str):
    """Build the Llama model for ``model_path``; successful loads are cached.

    Kept separate from ``load_model`` so that only a successfully constructed
    model is stored: an exception raised here propagates to the caller and
    leaves nothing in the cache, so a later call can retry.
    """
    return Llama(model_path=model_path, n_ctx=2048, n_threads=4, verbose=False)


def load_model():
    """Return the local GGUF model, or ``None`` when it cannot be loaded.

    The model is expected at ``models/pentest_ai.Q4_0.gguf`` relative to the
    working directory. Problems are reported in the page with ``st.error``
    rather than raised.

    Returns:
        The ``Llama`` instance, or ``None`` if the file is missing or loading
        failed.
    """
    # Model path consistent across environments
    model_path = os.path.join("models", "pentest_ai.Q4_0.gguf")
    if not os.path.exists(model_path):
        st.error(f"Model file not found at {model_path}. Please ensure it's placed correctly.")
        return None
    try:
        return _load_cached_model(model_path)
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None
# Execute tools asynchronously
async def run_tool(command: str) -> str:
    """Run an external tool and return its textual output.

    The command line is split with shell-style quoting rules and the program
    is started directly, without a shell. Metacharacters such as ``;``, ``|``
    or ``$(...)`` therefore reach the program as literal arguments instead of
    being interpreted.

    Args:
        command: Program name followed by its arguments.

    Returns:
        The program's standard output if it produced any, otherwise its
        standard error. If the command is empty, cannot be parsed, or cannot
        be started, a message beginning with ``"Error executing tool:"``.
    """
    # NOTE(review): any program on PATH can still be launched; consider an
    # allowlist of tool names if the command text is not fully trusted.
    try:
        argv = shlex.split(command)
        if not argv:
            return "Error executing tool: empty command"
        process = await asyncio.create_subprocess_exec(
            *argv,
            # No interactive input is available, so never let a tool wait on stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        output = stdout if stdout else stderr
        # Tools may emit bytes that are not valid UTF-8; keep the rest readable.
        return output.decode(errors="replace")
    except Exception as e:
        return f"Error executing tool: {str(e)}"
# Fetch vulnerability info via web scraping (no API keys)
def get_vulnerability_info(query: str) -> str:
    """Look up CVE entries matching ``query`` on cve.mitre.org.

    Args:
        query: Keyword to search for. It is sent as a URL parameter, so
            spaces and characters such as ``&`` or ``#`` are encoded instead
            of altering the request.

    Returns:
        Up to five ``"<first cell>: <second cell>"`` lines taken from the
        table rows of the result page, ``"No vulnerabilities found."`` when
        no usable row exists, or a message beginning with
        ``"Error fetching vulnerability data:"`` on failure. Never raises.
    """
    try:
        response = requests.get(
            "https://cve.mitre.org/cgi-bin/cvekey.cgi",
            params={"keyword": query},
            timeout=10,
        )
        # Do not scrape an error page as if it were a result list.
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")

        vulns = []
        # The first row is skipped, as before; rows without two cells
        # (headers, layout rows) are ignored rather than failing the lookup.
        for row in soup.find_all("tr")[1:]:
            cells = row.find_all("td")
            if len(cells) < 2:
                continue
            identifier = cells[0].get_text(" ", strip=True)
            summary = cells[1].get_text(" ", strip=True)
            vulns.append(f"{identifier}: {summary}")
            if len(vulns) == 5:  # Top 5 results
                break
        return "\n".join(vulns) if vulns else "No vulnerabilities found."
    except Exception as e:
        # Deliberately broad: callers rely on always getting a string back.
        return f"Error fetching vulnerability data: {str(e)}"
# Session state management: create the chat history list if it does not exist yet.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
# Add message to chat history
def add_message(content: str, is_user: bool) -> None:
    """Append one message to the chat history kept in the session state.

    Args:
        content: Text of the message.
        is_user: True when the user wrote the message, False otherwise.
    """
    entry = dict(content=content, is_user=is_user)
    history = st.session_state.messages
    history.append(entry)
# Render chat history
def render_chat() -> None:
    """Draw every stored chat message as a styled HTML bubble.

    User messages get the ``user-message`` CSS class and all others get
    ``assistant-message``. Message content is placed into the markup as-is
    with raw HTML enabled, so any HTML it contains is rendered rather than
    shown as text.
    """
    bubble_template = '<div class="{}">{}</div>'
    for entry in st.session_state.messages:
        if entry["is_user"]:
            css_class = "user-message"
        else:
            css_class = "assistant-message"
        st.markdown(
            bubble_template.format(css_class, entry["content"]),
            unsafe_allow_html=True,
        )
# Main application

# Prompt preamble placed ahead of every user question.
_SYSTEM_PROMPT = """
You are a cybersecurity assistant with expertise in penetration testing.
Provide concise, actionable insights. Use [TOOL: tool_name ARGS: "args"] for tool suggestions.
"""

# Matches a tool suggestion of the form [TOOL: name ARGS: "arguments"].
_TOOL_PATTERN = re.compile(r"\[TOOL: (\w+) ARGS: \"(.*?)\"\]")

# Punctuation trimmed from the word used as the vulnerability search term.
_QUERY_PUNCTUATION = ".,;:!?\"'()"


def _to_chat_html(text: str) -> str:
    """Escape ``text`` for the raw-HTML chat bubbles, keeping its line breaks."""
    return html.escape(text).replace("\n", "<br>")


def _generate_reply(model, user_input: str, max_tokens: int) -> str:
    """Return the assistant's reply to ``user_input`` as an HTML fragment.

    The model's text, the output of any tool it suggests, and any
    vulnerability data are escaped before being combined, so only the markup
    added here is interpreted as HTML when the chat is rendered.
    """
    full_prompt = f"{_SYSTEM_PROMPT}\nUser: {user_input}\nAssistant:"
    response = model.create_completion(
        full_prompt, max_tokens=max_tokens, temperature=0.7, stop=["User:"]
    )
    generated_text = response["choices"][0]["text"].strip()
    reply = _to_chat_html(generated_text)

    # Parse for tool execution (first suggestion only).
    # NOTE(review): this runs a command chosen by the model without asking
    # the user; consider a confirmation step or an allowlist of tool names.
    match = _TOOL_PATTERN.search(generated_text)
    if match:
        tool_name, args = match.groups()
        tool_output = asyncio.run(run_tool(f"{tool_name} {args}"))
        reply += (
            "<br><br><div class='tool-output'>Tool Output:<br>"
            f"{_to_chat_html(tool_output)}</div>"
        )

    # Handle vulnerability lookups
    if "vulnerability" in user_input.lower():
        # Simplified query extraction: the last word, without punctuation.
        query = user_input.split()[-1].strip(_QUERY_PUNCTUATION)
        if query:
            vulns = get_vulnerability_info(query)
            reply += f"<br><br>Vulnerability Data:<br>{_to_chat_html(vulns)}"
    return reply


def main():
    """Render the assistant page and handle one chat submission.

    The chat container is created above the input form but filled last, so
    the question and answer added while handling a submission are shown in
    the same run. All user-, model- and tool-supplied text is escaped before
    it is stored, because the chat history is rendered as raw HTML.
    """
    # NOTE(review): the title glyph arrived mis-encoded (it read "π"); a
    # padlock is presumed to be the intended emoji -- confirm with the author.
    st.title("🔒 Open-Source Security Assistant")
    st.markdown("Powered by pentest_ai.Q4_0.gguf. Runs locally or on Hugging Face Spaces.")

    # Sidebar for settings
    with st.sidebar:
        max_tokens = st.slider("Max Tokens", 128, 1024, 256)
        if st.button("Clear Chat"):
            st.session_state.messages = []

    # Load model
    model = load_model()
    if model is None:
        st.warning("Model loading failed. Check logs or ensure the model file is available.")
        return

    # Reserve the chat area here (above the form); it is filled at the end.
    chat_area = st.container()

    # Chat input form
    with st.form("chat_form", clear_on_submit=True):
        user_input = st.text_area("Ask a security question...", height=100)
        submit = st.form_submit_button("Send")

    question = user_input.strip()
    if submit and question:
        add_message(_to_chat_html(question), True)
        reply = None
        with st.spinner("Processing..."):
            try:
                reply = _generate_reply(model, question, max_tokens)
            except Exception as e:
                # UI boundary: report the failure instead of crashing the page.
                st.error(f"Failed to generate a response: {e}")
        if reply is not None:
            add_message(reply, False)

    with chat_area:
        render_chat()


if __name__ == "__main__":
    main()