Spaces:
Running
Running
| import gradio as gr | |
| from openai import OpenAI | |
| import snowflake.connector | |
| import os | |
| import json | |
| from decimal import Decimal | |
| from datetime import date, datetime | |
| from urllib.parse import urlencode | |
| from utils.functions import ( | |
| intraday_stock_prices, | |
| daily_stock_prices, | |
| get_income_statement, | |
| ticker_search, | |
| company_profile, | |
| current_market_cap, | |
| historical_market_cap, | |
| analyst_recommendations, | |
| stock_peers, | |
| earnings_historical_and_upcoming | |
| ) | |
# Shared OpenAI client used by every assistant call in this module. The key is
# looked up with os.environ[...] at import time, so a missing OPENAI_API_KEY
# raises KeyError immediately rather than on the first request.
client = OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
)
def _custom_json_serializer(obj):
    """Encode values that ``json.dumps`` cannot handle natively.

    Intended as the ``default=`` hook of ``json.dumps``.

    Args:
        obj: A value the JSON encoder could not serialize on its own.

    Returns:
        An ISO-8601 string for ``date``/``datetime`` values, or a ``float``
        for ``Decimal`` values.

    Raises:
        TypeError: For any other type, as the ``json`` module expects.
    """
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if isinstance(obj, Decimal):
        # JSON has no decimal type; float may lose precision on very long values.
        return float(obj)
    raise TypeError(f"Type {type(obj)} not serializable")


def fetch_trades_of_the_day():
    """Fetch the RSI "trade of the day" rows from Snowflake as a JSON string.

    Runs a fixed query selecting ``BEST_TRADE_STRING`` from
    ``RESEARCHDATA.RSI_TRADE_OF_THE_DAY``, ordered by ``rk`` descending.
    The rows are returned as an indented JSON array of objects keyed by
    column name.

    Connection settings are read from the environment variables
    ``SNOWFLAKE_USER``, ``SNOWFLAKE_PW``, ``SNOWFLAKE_ACCOUNT``,
    ``SNOWFLAKE_WH``, ``SNOWFLAKE_DB`` and ``SNOWFLAKE_SCHEMA``.

    Returns:
        str | None: The JSON string, or ``None`` if anything fails. Failures
        include a missing env var, a connection error, a query error or a
        serialization error; the error is printed.
    """
    query = (
        "select BEST_TRADE_STRING from RESEARCHDATA.RSI_TRADE_OF_THE_DAY rs "
        "order by rk desc;"
    )
    conn = None
    cur = None
    try:
        conn = snowflake.connector.connect(
            user=os.environ['SNOWFLAKE_USER'],
            password=os.environ['SNOWFLAKE_PW'],
            account=os.environ['SNOWFLAKE_ACCOUNT'],
            warehouse=os.environ['SNOWFLAKE_WH'],
            database=os.environ['SNOWFLAKE_DB'],
            schema=os.environ['SNOWFLAKE_SCHEMA'],
        )
        cur = conn.cursor()
        rows = cur.execute(query).fetchall()
        columns = [desc[0] for desc in cur.description]  # column names, in order
        result = [dict(zip(columns, row)) for row in rows]
        return json.dumps(result, indent=4, default=_custom_json_serializer)
    except Exception as e:
        # Any failure (not only connecting) ends up here; the caller gets None.
        print(f"Failed to fetch trades of the day from Snowflake: {e}")
        return None
    finally:
        # Release the cursor, then the connection, even when the query failed.
        # A failing close must not mask the result or the original error.
        for handle in (cur, conn):
            if handle is not None:
                try:
                    handle.close()
                except Exception as close_error:
                    print(f"Failed to close Snowflake handle: {close_error}")
def interact_with_assistant(user_input, poll_interval=1.0):
    """Send one message to the configured OpenAI assistant and return its reply.

    A fresh thread is created for every call. The run is polled until it
    completes. While polling, ``fetch_trades_of_the_day`` tool calls are
    executed locally and their results are submitted back to the run.

    Args:
        user_input: Text of the user message to send.
        poll_interval: Seconds to wait between run-status polls (default 1.0).

    Returns:
        str: Text of the first content part of the first message returned by
        ``messages.list``. With the API's default descending order, that is
        the newest message.

    Raises:
        KeyError: If ``ASSISTANT_ID`` is not set in the environment.
        RuntimeError: If the run reaches a terminal status other than
            ``completed``.
    """
    import time

    # Statuses from which a run will never reach "completed"; without this
    # check the polling loop below would spin forever.
    terminal_failures = {"failed", "cancelled", "expired", "incomplete"}

    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=user_input,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=os.environ['ASSISTANT_ID'],
    )

    while run.status != 'completed':
        if run.status in terminal_failures:
            raise RuntimeError(
                f"Assistant run ended with status '{run.status}': "
                f"{getattr(run, 'last_error', None)}"
            )
        # Throttle polling instead of hammering the API in a tight loop.
        time.sleep(poll_interval)
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
        if run.status == 'requires_action':
            tool_outputs = []
            for tool_call in run.required_action.submit_tool_outputs.tool_calls:
                if tool_call.function.name == "fetch_trades_of_the_day":
                    output = fetch_trades_of_the_day()
                    if output is None:
                        # The API expects a string output; report the failure
                        # to the model instead of submitting None.
                        output = json.dumps(
                            {"error": "Failed to fetch trades of the day."}
                        )
                else:
                    # The API expects an output for every tool call it issued,
                    # so answer unknown tools explicitly rather than omit them.
                    output = json.dumps(
                        {"error": f"Unknown tool: {tool_call.function.name}"}
                    )
                tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id,
                run_id=run.id,
                tool_outputs=tool_outputs,
            )

    messages = client.beta.threads.messages.list(thread_id=thread.id)
    return messages.data[0].content[0].text.value
def fetch_best_trades():
    """Ask the assistant for today's best trades without ever raising.

    Returns:
        str: The assistant's reply. If anything raises along the way, returns
        the message ``"An error occurred: <details>"`` instead, so the UI
        always has text to display.
    """
    prompt = "What are the best trades for today?"
    try:
        answer = interact_with_assistant(prompt)
    except Exception as exc:
        return "An error occurred: " + str(exc)
    return answer
# Stylesheet handed to gr.Blocks:
# - light-grey page background;
# - a centred white card (".container") holding the main column;
# - spacing below the recommendations box (".output-box").
css = """
body {
    font-family: Arial, sans-serif;
    background-color: #f0f2f5;
}
.container {
    margin: 0 auto;
    padding: 20px;
    background-color: white;
    border-radius: 10px;
    box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
}
.output-box {
    margin-bottom: 20px;
}
"""
# Single-page UI: a title, a read-only text box for the assistant's answer,
# and one button that triggers fetch_best_trades (which takes no inputs).
# NOTE(review): the emoji in the title and button were reconstructed from
# mojibake ("π") in the scraped source; confirm the intended glyphs.
with gr.Blocks(css=css) as iface:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# 📈 Stock Market Assistant")
        gr.Markdown("Get insights on the best trades for today based on RSI data.")
        output = gr.Textbox(
            label="Trade Recommendations",
            lines=20,
            interactive=False,
            elem_classes="output-box",
        )
        fetch_button = gr.Button("🔍 Fetch me the best trades for today", variant="primary")
        fetch_button.click(fn=fetch_best_trades, outputs=output)

# Only start the (blocking) server when run as a script, so importing this
# module, e.g. for tests or tooling, has no side effects.
if __name__ == "__main__":
    iface.launch()