# CyberPredator / app.py
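"""Streamlit chat app that runs a local pentest-oriented GGUF model (pentest_ai.Q4_0.gguf)
through llama-cpp-python.

The assistant can suggest shell tools via [TOOL: ...] tags, which are executed
asynchronously, and can look up CVE entries by scraping the MITRE keyword-search page.

Dependencies: streamlit, llama-cpp-python, requests, beautifulsoup4.
"""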
import streamlit as st
import asyncio
import re
import os
from llama_cpp import Llama
import requests
from bs4 import BeautifulSoup
# Set page configuration
st.set_page_config(page_title="Security Assistant", page_icon="🔒", layout="wide")
# Custom CSS for styling
st.markdown(
    """
    <style>
    .user-message { background-color: #DCF8C6; padding: 10px; border-radius: 10px; margin: 5px 0; }
    .assistant-message { background-color: #E9ECEF; padding: 10px; border-radius: 10px; margin: 5px 0; }
    .tool-output { background-color: #F8F9FA; padding: 10px; border-radius: 10px; border: 1px solid #DEE2E6; }
    </style>
    """,
    unsafe_allow_html=True
)
# Cache the model loading
@st.cache_resource
def load_model():
    # Model path consistent across environments
    model_path = os.path.join("models", "pentest_ai.Q4_0.gguf")
    if not os.path.exists(model_path):
        st.error(f"Model file not found at {model_path}. Please ensure it's placed correctly.")
        return None
    try:
        model = Llama(model_path=model_path, n_ctx=2048, n_threads=4, verbose=False)
        return model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None
# Execute tools asynchronously
async def run_tool(command: str) -> str:
    try:
        process = await asyncio.create_subprocess_shell(
            command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
        stdout, stderr = await process.communicate()
        return stdout.decode() if stdout else stderr.decode()
    except Exception as e:
        return f"Error executing tool: {str(e)}"
# Fetch vulnerability info via web scraping (no API keys)
def get_vulnerability_info(query: str) -> str:
    try:
        url = f"https://cve.mitre.org/cgi-bin/cvekey.cgi?keyword={query}"
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        results = soup.find_all("tr")[1:6]  # Top 5 results
        vulns = []
        for row in results:
            cells = row.find_all("td")
            if len(cells) >= 2:  # Skip rows without CVE ID / description cells
                vulns.append(f"{cells[0].text.strip()}: {cells[1].text.strip()}")
        return "\n".join(vulns) if vulns else "No vulnerabilities found."
    except Exception as e:
        return f"Error fetching vulnerability data: {str(e)}"
# Session state management
if "messages" not in st.session_state:
st.session_state.messages = []
# Add message to chat history
def add_message(content: str, is_user: bool):
    st.session_state.messages.append({"content": content, "is_user": is_user})
# Render chat history
def render_chat():
    for msg in st.session_state.messages:
        bubble_class = "user-message" if msg["is_user"] else "assistant-message"
        st.markdown(f'<div class="{bubble_class}">{msg["content"]}</div>', unsafe_allow_html=True)
# Main application
def main():
    st.title("🔒 Open-Source Security Assistant")
    st.markdown("Powered by pentest_ai.Q4_0.gguf. Runs locally or on Hugging Face Spaces.")
    # Sidebar for settings
    with st.sidebar:
        max_tokens = st.slider("Max Tokens", 128, 1024, 256)
        if st.button("Clear Chat"):
            st.session_state.messages = []
    # Load model
    model = load_model()
    if not model:
        st.warning("Model loading failed. Check logs or ensure the model file is available.")
        return
    render_chat()
    # Chat input form
    with st.form("chat_form", clear_on_submit=True):
        user_input = st.text_area("Ask a security question...", height=100)
        submit = st.form_submit_button("Send")
    if submit and user_input:
        add_message(user_input, True)
        with st.spinner("Processing..."):
            # Prepare prompt
            system_prompt = """
            You are a cybersecurity assistant with expertise in penetration testing.
            Provide concise, actionable insights. Use [TOOL: tool_name ARGS: "args"] for tool suggestions.
            """
            full_prompt = f"{system_prompt}\nUser: {user_input}\nAssistant:"
            # Generate response
            response = model.create_completion(
                full_prompt, max_tokens=max_tokens, temperature=0.7, stop=["User:"]
            )
            generated_text = response["choices"][0]["text"].strip()
            # Parse for tool execution
            tool_pattern = r"\[TOOL: (\w+) ARGS: \"(.*?)\"\]"
            match = re.search(tool_pattern, generated_text)
            if match:
                tool_name, args = match.groups()
                tool_output = asyncio.run(run_tool(f"{tool_name} {args}"))
                generated_text += f"\n\n<div class='tool-output'>Tool Output:\n{tool_output}</div>"
            # Handle vulnerability lookups
            if "vulnerability" in user_input.lower():
                query = user_input.split()[-1]  # Simplified query extraction
                vulns = get_vulnerability_info(query)
                generated_text += f"\n\nVulnerability Data:\n{vulns}"
        add_message(generated_text, False)
        st.rerun()  # Rerun so render_chat() above displays the newly added messages
if __name__ == "__main__":
    main()
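# To try it locally (assuming the dependencies above are installed and the GGUF
# model is placed at models/pentest_ai.Q4_0.gguf):
#   streamlit run app.py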