| | """ |
| | Streamlit app - CatVader Social Media Classifier |
| | Migrated from Gradio for better mobile support |
| | """ |
| |
|
| | import streamlit as st |
| | import pandas as pd |
| | import tempfile |
| | import os |
| | import time |
| | import sys |
| | from datetime import datetime |
| | import matplotlib.pyplot as plt |
| |
|
| | |
| | try: |
| | import catvader |
| | CATVADER_AVAILABLE = True |
| | except ImportError as e: |
| | print(f"Warning: Could not import catvader: {e}") |
| | CATVADER_AVAILABLE = False |
| |
|
| | MAX_CATEGORIES = 10 |
| | INITIAL_CATEGORIES = 3 |
| | MAX_FILE_SIZE_MB = 100 |
| |
|
| | def count_pdf_pages(pdf_path): |
| | """Count the number of pages in a PDF file.""" |
| | try: |
| | import fitz |
| | doc = fitz.open(pdf_path) |
| | page_count = len(doc) |
| | doc.close() |
| | return page_count |
| | except Exception: |
| | return 1 |
| |
|
| |
|
| | def extract_text_from_pdfs(pdf_paths): |
| | """Extract text from all pages of all PDFs, returning list of page texts.""" |
| | import fitz |
| | all_texts = [] |
| | for pdf_path in pdf_paths: |
| | try: |
| | doc = fitz.open(pdf_path) |
| | for page in doc: |
| | text = page.get_text().strip() |
| | if text: |
| | all_texts.append(text) |
| | doc.close() |
| | except Exception as e: |
| | print(f"Error extracting text from {pdf_path}: {e}") |
| | return all_texts |
| |
|
| |
|
| | def extract_pdf_pages(pdf_paths, pdf_name_map, mode="image"): |
| | """ |
| | Extract individual pages from PDFs. |
| | Returns list of (page_data, page_label) tuples. |
| | For image mode: page_data is path to temp image file |
| | For text mode: page_data is extracted text |
| | """ |
| | import fitz |
| | pages = [] |
| |
|
| | for pdf_path in pdf_paths: |
| | orig_name = pdf_name_map.get(pdf_path, os.path.basename(pdf_path).replace('.pdf', '')) |
| | try: |
| | doc = fitz.open(pdf_path) |
| | for page_num, page in enumerate(doc, 1): |
| | page_label = f"{orig_name}_p{page_num}" |
| |
|
| | if mode == "text": |
| | |
| | text = page.get_text().strip() |
| | if text: |
| | pages.append((text, page_label, "text")) |
| | else: |
| | |
| | pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) |
| | img_path = tempfile.NamedTemporaryFile(delete=False, suffix='.png').name |
| | pix.save(img_path) |
| |
|
| | if mode == "both": |
| | text = page.get_text().strip() |
| | pages.append((img_path, page_label, "image", text)) |
| | else: |
| | pages.append((img_path, page_label, "image")) |
| | doc.close() |
| | except Exception as e: |
| | print(f"Error extracting pages from {pdf_path}: {e}") |
| |
|
| | return pages |
| |
|
| | |
| | FREE_MODELS_MAP = { |
| | "GPT-4o Mini": "gpt-4o-mini", |
| | "Gemini 2.5 Flash": "gemini-2.5-flash", |
| | "Claude 3 Haiku": "claude-3-haiku-20240307", |
| | "Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct:groq", |
| | "Qwen 2.5": "Qwen/Qwen2.5-72B-Instruct", |
| | "DeepSeek R1": "deepseek-ai/DeepSeek-R1:novita", |
| | "Mistral Medium": "mistral-medium-2505", |
| | "Grok 4 Fast": "grok-4-fast-non-reasoning", |
| | } |
| | FREE_MODEL_DISPLAY_NAMES = list(FREE_MODELS_MAP.keys()) |
| | FREE_MODEL_CHOICES = list(FREE_MODELS_MAP.values()) |
| |
|
| | |
| | PAID_MODEL_CHOICES = [ |
| | "gemini-2.5-flash", |
| | "gemini-2.5-pro", |
| | "gpt-4.1", |
| | "gpt-4o", |
| | "gpt-4o-mini", |
| | "claude-sonnet-4-5-20250929", |
| | "claude-opus-4-20250514", |
| | "claude-3-5-haiku-20241022", |
| | "mistral-large-latest", |
| | ] |
| |
|
| | |
| | HF_ROUTED_MODELS = [ |
| | "meta-llama/Llama-3.3-70B-Instruct:groq", |
| | "deepseek-ai/DeepSeek-R1:novita", |
| | ] |
| |
|
| |
|
| | def is_free_model(model, model_tier): |
| | """Check if using free tier (Space pays for API).""" |
| | return model_tier == "Free Models" |
| |
|
| |
|
| | def get_model_source(model): |
| | """Auto-detect model source.""" |
| | model_lower = model.lower() |
| | if "gpt" in model_lower: |
| | return "openai" |
| | elif "claude" in model_lower: |
| | return "anthropic" |
| | elif "gemini" in model_lower: |
| | return "google" |
| | elif "mistral" in model_lower and ":novita" not in model_lower: |
| | return "mistral" |
| | elif any(x in model_lower for x in [":novita", ":groq", "qwen", "llama", "deepseek"]): |
| | return "huggingface" |
| | elif "sonar" in model_lower: |
| | return "perplexity" |
| | elif "grok" in model_lower: |
| | return "xai" |
| | return "huggingface" |
| |
|
| |
|
| | def get_api_key(model, model_tier, api_key_input): |
| | """Get the appropriate API key based on model and tier.""" |
| | if is_free_model(model, model_tier): |
| | if model in HF_ROUTED_MODELS: |
| | return os.environ.get("HF_API_KEY", ""), "HuggingFace" |
| | elif "gpt" in model.lower(): |
| | return os.environ.get("OPENAI_API_KEY", ""), "OpenAI" |
| | elif "gemini" in model.lower(): |
| | return os.environ.get("GOOGLE_API_KEY", ""), "Google" |
| | elif "mistral" in model.lower(): |
| | return os.environ.get("MISTRAL_API_KEY", ""), "Mistral" |
| | elif "claude" in model.lower(): |
| | return os.environ.get("ANTHROPIC_API_KEY", ""), "Anthropic" |
| | elif "sonar" in model.lower(): |
| | return os.environ.get("PERPLEXITY_API_KEY", ""), "Perplexity" |
| | elif "grok" in model.lower(): |
| | return os.environ.get("XAI_API_KEY", ""), "xAI" |
| | else: |
| | return os.environ.get("HF_API_KEY", ""), "HuggingFace" |
| | else: |
| | if api_key_input and api_key_input.strip(): |
| | return api_key_input.strip(), "User" |
| | return "", "User" |
| |
|
| |
|
| | def calculate_total_file_size(files): |
| | """Calculate total size of uploaded files in MB.""" |
| | if files is None: |
| | return 0 |
| | if not isinstance(files, list): |
| | files = [files] |
| |
|
| | total_bytes = 0 |
| | for f in files: |
| | try: |
| | if hasattr(f, 'size'): |
| | total_bytes += f.size |
| | elif hasattr(f, 'name'): |
| | total_bytes += os.path.getsize(f.name) |
| | except (OSError, AttributeError): |
| | pass |
| | return total_bytes / (1024 * 1024) |
| |
|
| |
|
| | def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None): |
| | """Generate Python code for category extraction.""" |
| | if input_type == "text": |
| | return f'''import catvader |
| | import pandas as pd |
| | |
| | # Load your data |
| | df = pd.read_csv("your_data.csv") |
| | |
| | # Extract categories from the text column |
| | result = catvader.extract( |
| | input_data=df["{description}"].tolist(), |
| | api_key="YOUR_API_KEY", |
| | input_type="text", |
| | description="{description}", |
| | user_model="{model}", |
| | model_source="{model_source}", |
| | max_categories={max_categories} |
| | ) |
| | |
| | # View extracted categories |
| | print(result["top_categories"]) |
| | print(result["counts_df"]) |
| | ''' |
| | elif input_type == "pdf": |
| | mode_line = f',\n mode="{mode}"' if mode else '' |
| | return f'''import catvader |
| | |
| | # Extract categories from PDF documents |
| | result = catvader.extract( |
| | input_data="path/to/your/pdfs/", |
| | api_key="YOUR_API_KEY", |
| | input_type="pdf", |
| | description="{description}"{mode_line}, |
| | user_model="{model}", |
| | model_source="{model_source}", |
| | max_categories={max_categories} |
| | ) |
| | |
| | # View extracted categories |
| | print(result["top_categories"]) |
| | print(result["counts_df"]) |
| | ''' |
| | else: |
| | return f'''import catvader |
| | |
| | # Extract categories from images |
| | result = catvader.extract( |
| | input_data="path/to/your/images/", |
| | api_key="YOUR_API_KEY", |
| | input_type="image", |
| | description="{description}", |
| | user_model="{model}", |
| | model_source="{model_source}", |
| | max_categories={max_categories} |
| | ) |
| | |
| | # View extracted categories |
| | print(result["top_categories"]) |
| | print(result["counts_df"]) |
| | ''' |
| |
|
| |
|
| | def generate_full_code(extraction_params, classify_params): |
| | """Generate combined extract + classify code when categories were auto-extracted.""" |
| | ext = extraction_params |
| | cls = classify_params |
| |
|
| | |
| | if ext['input_type'] == "text": |
| | input_placeholder = 'df["your_column"].tolist()' |
| | load_data = '''import pandas as pd |
| | |
| | # Load your data |
| | df = pd.read_csv("your_data.csv") |
| | ''' |
| | elif ext['input_type'] == "pdf": |
| | input_placeholder = '"path/to/your/pdfs/"' |
| | load_data = '' |
| | else: |
| | input_placeholder = '"path/to/your/images/"' |
| | load_data = '' |
| |
|
| | mode_param = f',\n mode="{ext["mode"]}"' if ext.get('mode') else '' |
| |
|
| | |
| | extract_code = f'''# Step 1: Extract categories from your data |
| | extract_result = catvader.extract( |
| | input_data={input_placeholder}, |
| | api_key="YOUR_API_KEY", |
| | description="{ext['description']}", |
| | user_model="{ext['model']}", |
| | max_categories={ext['max_categories']}{mode_param} |
| | ) |
| | |
| | categories = extract_result["top_categories"] |
| | print(f"Extracted {{len(categories)}} categories: {{categories}}") |
| | ''' |
| |
|
| | |
| | if cls['classify_mode'] == "Single Model": |
| | classify_mode_param = f',\n mode="{cls["mode"]}"' if cls.get('mode') and ext['input_type'] == "pdf" else '' |
| | classify_code = f''' |
| | # Step 2: Classify data using extracted categories |
| | result = catvader.classify( |
| | input_data={input_placeholder}, |
| | categories=categories, |
| | api_key="YOUR_API_KEY", |
| | description="{cls['description']}", |
| | user_model="{cls['model']}"{classify_mode_param} |
| | )''' |
| | else: |
| | |
| | ens_runs = cls.get('ensemble_runs') |
| | model_lines = [] |
| | if ens_runs: |
| | for m, temp in ens_runs: |
| | model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})') |
| | else: |
| | model_temps = cls.get('model_temperatures', {}) |
| | for m in cls['models_list']: |
| | temp = model_temps.get(m) if model_temps else None |
| | if temp is not None: |
| | model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})') |
| | else: |
| | model_lines.append(f'("{m}", "auto", "YOUR_API_KEY")') |
| | models_str = ",\n ".join(model_lines) |
| |
|
| | classify_mode_param = f',\n mode="{cls["mode"]}"' if cls.get('mode') and ext['input_type'] == "pdf" else '' |
| | threshold_str = "majority" if cls['consensus_threshold'] == 0.5 else "two-thirds" if cls['consensus_threshold'] == 0.67 else "unanimous" |
| | consensus_param = f',\n consensus_threshold="{threshold_str}"' if cls['classify_mode'] == "Ensemble" else '' |
| |
|
| | classify_code = f''' |
| | # Step 2: Classify data using extracted categories with {"ensemble voting" if cls['classify_mode'] == "Ensemble" else "model comparison"} |
| | models = [ |
| | {models_str} |
| | ] |
| | |
| | result = catvader.classify( |
| | input_data={input_placeholder}, |
| | categories=categories, |
| | models=models, |
| | description="{cls['description']}"{classify_mode_param}{consensus_param} |
| | )''' |
| |
|
| | return f'''import catvader |
| | {load_data} |
| | {extract_code} |
| | {classify_code} |
| | |
| | # View results |
| | print(result) |
| | result.to_csv("classified_results.csv", index=False) |
| | ''' |
| |
|
| |
|
| | def generate_classify_code(input_type, description, categories, model, model_source, mode=None, |
| | classify_mode="Single Model", models_list=None, consensus_threshold=0.5, |
| | model_temperatures=None, ensemble_runs=None): |
| | """Generate Python code for classification.""" |
| | categories_str = ",\n ".join([f'"{cat}"' for cat in categories]) |
| |
|
| | |
| | if input_type == "text": |
| | input_placeholder = 'df["your_column"].tolist()' |
| | load_data = '''import pandas as pd |
| | |
| | # Load your data |
| | df = pd.read_csv("your_data.csv") |
| | ''' |
| | elif input_type == "pdf": |
| | input_placeholder = '"path/to/your/pdfs/"' |
| | load_data = '' |
| | else: |
| | input_placeholder = '"path/to/your/images/"' |
| | load_data = '' |
| |
|
| | |
| | if classify_mode == "Single Model": |
| | |
| | mode_param = f',\n mode="{mode}"' if mode and input_type == "pdf" else '' |
| | return f'''import catvader |
| | {load_data} |
| | # Define categories |
| | categories = [ |
| | {categories_str} |
| | ] |
| | |
| | # Classify data (input type is auto-detected) |
| | result = catvader.classify( |
| | input_data={input_placeholder}, |
| | categories=categories, |
| | api_key="YOUR_API_KEY", |
| | description="{description}", |
| | user_model="{model}"{mode_param} |
| | ) |
| | |
| | # View results |
| | print(result) |
| | result.to_csv("classified_results.csv", index=False) |
| | ''' |
| | else: |
| | |
| | |
| | if ensemble_runs: |
| | |
| | model_lines = [] |
| | for m, temp in ensemble_runs: |
| | model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})') |
| | models_str = ",\n ".join(model_lines) |
| | elif models_list: |
| | model_lines = [] |
| | for m in models_list: |
| | temp = model_temperatures.get(m) if model_temperatures else None |
| | if temp is not None: |
| | model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})') |
| | else: |
| | model_lines.append(f'("{m}", "auto", "YOUR_API_KEY")') |
| | models_str = ",\n ".join(model_lines) |
| | else: |
| | models_str = '("gpt-4o", "auto", "YOUR_API_KEY"),\n ("claude-sonnet-4-5-20250929", "auto", "YOUR_API_KEY")' |
| |
|
| | mode_param = f',\n mode="{mode}"' if mode and input_type == "pdf" else '' |
| | |
| | threshold_str = "majority" if consensus_threshold == 0.5 else "two-thirds" if consensus_threshold == 0.67 else "unanimous" |
| | consensus_param = f',\n consensus_threshold="{threshold_str}"' if classify_mode == "Ensemble" else '' |
| |
|
| | return f'''import catvader |
| | {load_data} |
| | # Define categories |
| | categories = [ |
| | {categories_str} |
| | ] |
| | |
| | # Define models for {"ensemble voting" if classify_mode == "Ensemble" else "comparison"} |
| | models = [ |
| | {models_str} |
| | ] |
| | |
| | # Classify with multiple models |
| | result = catvader.classify( |
| | input_data={input_placeholder}, |
| | categories=categories, |
| | models=models, |
| | description="{description}"{mode_param}{consensus_param} |
| | ) |
| | |
| | # View results |
| | print(result) |
| | result.to_csv("classified_results.csv", index=False) |
| | ''' |
| |
|
| |
|
| | def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate, |
| | result_df=None, processing_time=None, prompt_template=None, |
| | data_quality=None, catvader_version=None, python_version=None, |
| | task_type="assign", extracted_categories_df=None, max_categories=None, |
| | input_type="text", description=None, classify_mode="Single Model", |
| | models_list=None, code=None, consensus_threshold=None): |
| | """Generate a PDF methodology report.""" |
| | from reportlab.lib.pagesizes import letter |
| | from reportlab.lib import colors |
| | from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle |
| | from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak |
| |
|
| | pdf_file = tempfile.NamedTemporaryFile(mode='wb', suffix='_methodology_report.pdf', delete=False) |
| | doc = SimpleDocTemplate(pdf_file.name, pagesize=letter) |
| | styles = getSampleStyleSheet() |
| |
|
| | title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, spaceAfter=20) |
| | heading_style = ParagraphStyle('Heading', parent=styles['Heading2'], fontSize=14, spaceAfter=10, spaceBefore=15) |
| | normal_style = styles['Normal'] |
| | code_style = ParagraphStyle('Code', parent=styles['Normal'], fontName='Courier', fontSize=9, leftIndent=20, spaceAfter=3) |
| |
|
| | story = [] |
| |
|
| | if task_type == "extract_and_assign": |
| | report_title = "CatVader Extraction & Classification Report" |
| | else: |
| | report_title = "CatVader Classification Report" |
| |
|
| | story.append(Paragraph(report_title, title_style)) |
| | story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", normal_style)) |
| | story.append(Spacer(1, 15)) |
| |
|
| | story.append(Paragraph("About This Report", heading_style)) |
| |
|
| | if task_type == "extract_and_assign": |
| | about_text = """This methodology report documents the automated category extraction and classification process. \ |
| | CatVader first discovers categories from your data using LLMs, then classifies each item into those categories.""" |
| | else: |
| | about_text = """This methodology report documents the classification process for reproducibility and transparency. \ |
| | CatVader restricts the prompt to a standard template that is impartial to the researcher's inclinations, ensuring \ |
| | consistent and reproducible results.""" |
| |
|
| | story.append(Paragraph(about_text, normal_style)) |
| | story.append(Spacer(1, 15)) |
| |
|
| | if categories: |
| | story.append(Paragraph("Category Mapping", heading_style)) |
| |
|
| | if classify_mode in ("Ensemble", "Model Comparison") and result_df is not None: |
| | |
| | story.append(Paragraph("Each model produces its own binary columns. " |
| | "Consensus columns show the majority vote result.", normal_style)) |
| | story.append(Spacer(1, 8)) |
| |
|
| | |
| | |
| | all_suffixes = _find_all_model_suffixes(result_df) |
| |
|
| | category_data = [["Column Name", "Category Description"]] |
| | for i, cat in enumerate(categories, 1): |
| | |
| | for suffix in all_suffixes: |
| | category_data.append([f"category_{i}_{suffix}", f"{cat} ({suffix})"]) |
| | |
| | category_data.append([f"category_{i}_consensus", f"{cat} (consensus)"]) |
| | category_data.append([f"category_{i}_agreement", f"{cat} (agreement score)"]) |
| |
|
| | cat_table = Table(category_data, colWidths=[200, 250]) |
| | else: |
| | |
| | story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style)) |
| | story.append(Spacer(1, 8)) |
| |
|
| | category_data = [["Column Name", "Category Description"]] |
| | for i, cat in enumerate(categories, 1): |
| | category_data.append([f"category_{i}", cat]) |
| |
|
| | cat_table = Table(category_data, colWidths=[120, 330]) |
| |
|
| | cat_table.setStyle(TableStyle([ |
| | ('BACKGROUND', (0, 0), (-1, 0), colors.grey), |
| | ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), |
| | ('GRID', (0, 0), (-1, -1), 1, colors.black), |
| | ('PADDING', (0, 0), (-1, -1), 6), |
| | ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey), |
| | ('FONTSIZE', (0, 0), (-1, -1), 9), |
| | ])) |
| | story.append(cat_table) |
| | story.append(Spacer(1, 15)) |
| |
|
| | story.append(Spacer(1, 30)) |
| | story.append(Paragraph("Citation", heading_style)) |
| | story.append(Paragraph("If you use CatVader in your research, please cite:", normal_style)) |
| | story.append(Spacer(1, 5)) |
| | story.append(Paragraph("Soria, C. (2025). CatVader: A Python package for LLM-based social media classification. DOI: 10.5281/zenodo.15532316", normal_style)) |
| |
|
| | |
| | story.append(PageBreak()) |
| | story.append(Paragraph("Classification Summary", title_style)) |
| | story.append(Spacer(1, 15)) |
| |
|
| | summary_data = [ |
| | ["Source File", filename], |
| | ["Source Column", column_name], |
| | ["Classification Mode", classify_mode], |
| | ["Model(s) Used", model], |
| | ["Model Source", model_source], |
| | ["Rows Classified", str(num_rows)], |
| | ["Number of Categories", str(len(categories)) if categories else "0"], |
| | ["Success Rate", f"{success_rate:.2f}%"], |
| | ] |
| | |
| | if classify_mode == "Ensemble" and consensus_threshold is not None: |
| | threshold_labels = {0.5: "Majority (50%+)", 0.67: "Two-Thirds (67%+)", 1.0: "Unanimous (100%)"} |
| | threshold_label = threshold_labels.get(consensus_threshold, f"Custom ({consensus_threshold:.0%})") |
| | summary_data.append(["Consensus Threshold", threshold_label]) |
| |
|
| | summary_table = Table(summary_data, colWidths=[150, 300]) |
| | summary_table.setStyle(TableStyle([ |
| | ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey), |
| | ('GRID', (0, 0), (-1, -1), 1, colors.black), |
| | ('PADDING', (0, 0), (-1, -1), 6), |
| | ('FONTSIZE', (0, 0), (-1, -1), 9), |
| | ])) |
| | story.append(summary_table) |
| | story.append(Spacer(1, 15)) |
| |
|
| | |
| | if classify_mode == "Ensemble" and result_df is not None and categories: |
| | agreement_cols = [f"category_{i}_agreement" for i in range(1, len(categories) + 1)] |
| | has_agreement = all(col in result_df.columns for col in agreement_cols) |
| | if has_agreement: |
| | story.append(Paragraph("Ensemble Agreement Scores", heading_style)) |
| | story.append(Paragraph( |
| | "Agreement shows what proportion of models agreed on each category. " |
| | "Higher scores indicate stronger consensus.", normal_style)) |
| | story.append(Spacer(1, 8)) |
| |
|
| | agree_data = [["Category", "Mean Agreement", "Min Agreement"]] |
| | for i, cat in enumerate(categories, 1): |
| | col = f"category_{i}_agreement" |
| | mean_val = result_df[col].mean() |
| | min_val = result_df[col].min() |
| | agree_data.append([cat, f"{mean_val:.1%}", f"{min_val:.1%}"]) |
| |
|
| | agree_table = Table(agree_data, colWidths=[200, 125, 125]) |
| | agree_table.setStyle(TableStyle([ |
| | ('BACKGROUND', (0, 0), (-1, 0), colors.grey), |
| | ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke), |
| | ('GRID', (0, 0), (-1, -1), 1, colors.black), |
| | ('PADDING', (0, 0), (-1, -1), 6), |
| | ('FONTSIZE', (0, 0), (-1, -1), 9), |
| | ])) |
| | story.append(agree_table) |
| | story.append(Spacer(1, 15)) |
| |
|
| | if processing_time is not None: |
| | story.append(Paragraph("Processing Time", heading_style)) |
| | rows_per_min = (num_rows / processing_time) * 60 if processing_time > 0 else 0 |
| | avg_time = processing_time / num_rows if num_rows > 0 else 0 |
| |
|
| | time_data = [ |
| | ["Total Processing Time", f"{processing_time:.1f} seconds"], |
| | ["Average Time per Response", f"{avg_time:.2f} seconds"], |
| | ["Processing Rate", f"{rows_per_min:.1f} rows/minute"], |
| | ] |
| | time_table = Table(time_data, colWidths=[180, 270]) |
| | time_table.setStyle(TableStyle([ |
| | ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey), |
| | ('GRID', (0, 0), (-1, -1), 1, colors.black), |
| | ('PADDING', (0, 0), (-1, -1), 6), |
| | ('FONTSIZE', (0, 0), (-1, -1), 9), |
| | ])) |
| | story.append(time_table) |
| |
|
| | story.append(Spacer(1, 15)) |
| | story.append(Paragraph("Version Information", heading_style)) |
| | version_data = [ |
| | ["CatVader Version", catvader_version or "unknown"], |
| | ["Python Version", python_version or "unknown"], |
| | ["Timestamp", datetime.now().strftime('%Y-%m-%d %H:%M:%S')], |
| | ] |
| | version_table = Table(version_data, colWidths=[180, 270]) |
| | version_table.setStyle(TableStyle([ |
| | ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey), |
| | ('GRID', (0, 0), (-1, -1), 1, colors.black), |
| | ('PADDING', (0, 0), (-1, -1), 6), |
| | ('FONTSIZE', (0, 0), (-1, -1), 9), |
| | ])) |
| | story.append(version_table) |
| |
|
| | |
| | if code: |
| | story.append(PageBreak()) |
| | story.append(Paragraph("Reproducibility Code", title_style)) |
| | story.append(Paragraph("Use this Python code to reproduce the classification with the CatVader package:", normal_style)) |
| | story.append(Spacer(1, 10)) |
| |
|
| | |
| | for line in code.strip().split('\n'): |
| | |
| | escaped_line = line.replace('&', '&').replace('<', '<').replace('>', '>') |
| | if escaped_line.strip(): |
| | story.append(Paragraph(escaped_line, code_style)) |
| | else: |
| | story.append(Spacer(1, 6)) |
| |
|
| | |
| | if result_df is not None and categories: |
| | from reportlab.platypus import Image |
| | import io |
| |
|
| | |
| | story.append(PageBreak()) |
| | story.append(Paragraph("Category Distribution", title_style)) |
| | try: |
| | fig1 = create_distribution_chart(result_df, categories, classify_mode, models_list) |
| | img_buffer1 = io.BytesIO() |
| | fig1.savefig(img_buffer1, format='png', dpi=150, bbox_inches='tight') |
| | img_buffer1.seek(0) |
| | plt.close(fig1) |
| |
|
| | |
| | img_temp1 = tempfile.NamedTemporaryFile(delete=False, suffix='.png') |
| | img_temp1.write(img_buffer1.read()) |
| | img_temp1.close() |
| |
|
| | img1 = Image(img_temp1.name, width=450, height=250) |
| | story.append(img1) |
| | story.append(Spacer(1, 10)) |
| | story.append(Paragraph("Note: Categories are not mutually exclusive—each item can belong to multiple categories.", normal_style)) |
| | except Exception as e: |
| | story.append(Paragraph(f"Could not generate distribution chart: {str(e)}", normal_style)) |
| |
|
| | |
| | story.append(PageBreak()) |
| | story.append(Paragraph("Classification Matrix", title_style)) |
| | try: |
| | fig2 = create_classification_heatmap(result_df, categories, classify_mode, models_list) |
| | img_buffer2 = io.BytesIO() |
| | fig2.savefig(img_buffer2, format='png', dpi=150, bbox_inches='tight') |
| | img_buffer2.seek(0) |
| | plt.close(fig2) |
| |
|
| | |
| | img_temp2 = tempfile.NamedTemporaryFile(delete=False, suffix='.png') |
| | img_temp2.write(img_buffer2.read()) |
| | img_temp2.close() |
| |
|
| | img2 = Image(img_temp2.name, width=450, height=300) |
| | story.append(img2) |
| | story.append(Spacer(1, 10)) |
| | story.append(Paragraph("Orange = category present, Black = not present. Each row represents one response.", normal_style)) |
| | except Exception as e: |
| | story.append(Paragraph(f"Could not generate classification matrix: {str(e)}", normal_style)) |
| |
|
| | doc.build(story) |
| | return pdf_file.name |
| |
|
| |
|
| | def run_auto_extract(input_type, input_data, description, max_categories_val, |
| | model_tier, model, api_key_input, mode=None, progress_callback=None): |
| | """Extract categories from data.""" |
| | if not CATVADER_AVAILABLE: |
| | return None, "catvader package not available" |
| |
|
| | actual_api_key, provider = get_api_key(model, model_tier, api_key_input) |
| | if not actual_api_key: |
| | return None, f"{provider} API key not configured" |
| |
|
| | model_source = get_model_source(model) |
| |
|
| | try: |
| | if isinstance(input_data, list): |
| | num_items = len(input_data) |
| | else: |
| | num_items = 1 |
| |
|
| | if input_type == "image": |
| | divisions = min(3, max(1, num_items // 5)) |
| | categories_per_chunk = 12 |
| | else: |
| | divisions = max(1, num_items // 15) |
| | divisions = min(divisions, 5) |
| | chunk_size = num_items // max(1, divisions) |
| | categories_per_chunk = min(10, chunk_size - 1) |
| |
|
| | extract_kwargs = { |
| | 'input_data': input_data, |
| | 'api_key': actual_api_key, |
| | 'input_type': input_type, |
| | 'description': description, |
| | 'user_model': model, |
| | 'model_source': model_source, |
| | 'divisions': divisions, |
| | 'categories_per_chunk': categories_per_chunk, |
| | 'max_categories': int(max_categories_val) |
| | } |
| | if mode: |
| | extract_kwargs['mode'] = mode |
| |
|
| | extract_result = catvader.extract(**extract_kwargs) |
| | categories = extract_result.get('top_categories', []) |
| |
|
| | if not categories: |
| | return None, "No categories were extracted" |
| |
|
| | return categories, f"Extracted {len(categories)} categories successfully!" |
| |
|
| | except Exception as e: |
| | return None, f"Error: {str(e)}" |
| |
|
| |
|
| | def run_classify_data(input_type, input_data, description, categories, |
| | model_tier, model, api_key_input, mode=None, |
| | original_filename="data", column_name="text", |
| | progress_callback=None): |
| | """Classify data with user-provided categories.""" |
| | if not CATVADER_AVAILABLE: |
| | return None, None, None, None, "catvader package not available" |
| |
|
| | if not categories: |
| | return None, None, None, None, "Please enter at least one category" |
| |
|
| | actual_api_key, provider = get_api_key(model, model_tier, api_key_input) |
| | if not actual_api_key: |
| | return None, None, None, None, f"{provider} API key not configured" |
| |
|
| | model_source = get_model_source(model) |
| |
|
| | try: |
| | start_time = time.time() |
| |
|
| | classify_kwargs = { |
| | 'input_data': input_data, |
| | 'categories': categories, |
| | 'models': [(model, model_source, actual_api_key)], |
| | 'description': description, |
| | } |
| | if mode: |
| | classify_kwargs['mode'] = mode |
| |
|
| | result = catvader.classify(**classify_kwargs) |
| |
|
| | processing_time = time.time() - start_time |
| | num_items = len(result) |
| |
|
| | |
| | with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f: |
| | result.to_csv(f.name, index=False) |
| | csv_path = f.name |
| |
|
| | |
| | if 'processing_status' in result.columns: |
| | success_count = (result['processing_status'] == 'success').sum() |
| | success_rate = (success_count / len(result)) * 100 |
| | else: |
| | success_rate = 100.0 |
| |
|
| | |
| | try: |
| | catvader_version = catvader.__version__ |
| | except AttributeError: |
| | catvader_version = "unknown" |
| | python_version = sys.version.split()[0] |
| |
|
| | |
| | report_pdf_path = generate_methodology_report_pdf( |
| | categories=categories, |
| | model=model, |
| | column_name=column_name, |
| | num_rows=num_items, |
| | model_source=model_source, |
| | filename=original_filename, |
| | success_rate=success_rate, |
| | result_df=result, |
| | processing_time=processing_time, |
| | catvader_version=catvader_version, |
| | python_version=python_version, |
| | task_type="assign", |
| | input_type=input_type, |
| | description=description |
| | ) |
| |
|
| | |
| | code = generate_classify_code(input_type, description, categories, model, model_source, mode) |
| |
|
| | return result, csv_path, report_pdf_path, code, f"Classified {num_items} items in {processing_time:.1f}s" |
| |
|
| | except Exception as e: |
| | return None, None, None, None, f"Error: {str(e)}" |
| |
|
| |
|
| | def sanitize_model_name(model: str) -> str: |
| | """Convert model name to column-safe suffix (matches catvader logic).""" |
| | import re |
| | sanitized = re.sub(r'[^a-zA-Z0-9]', '_', model) |
| | sanitized = re.sub(r'_+', '_', sanitized) |
| | sanitized = sanitized.strip('_').lower() |
| | return sanitized[:40] |
| |
|
| |
|
| | def _find_model_column_suffix(result_df, model_name): |
| | """Find the actual column suffix used for a model in the DataFrame. |
| | |
| | catvader appends a creativity suffix (e.g. _tauto, _t50) to ensemble column |
| | names, so we can't just use sanitize_model_name(). This function looks at |
| | the real DataFrame columns to discover the full suffix. |
| | """ |
| | sanitized = sanitize_model_name(model_name) |
| | prefix = f"category_1_{sanitized}" |
| | for col in result_df.columns: |
| | if col.startswith(prefix): |
| | |
| | return col[len("category_1_"):] |
| | |
| | return sanitized |
| |
|
| |
|
| | def _find_all_model_suffixes(result_df): |
| | """Discover all distinct per-model column suffixes from the DataFrame. |
| | |
| | Looks at category_1_* columns (excluding _consensus and _agreement) |
| | to find every unique model suffix. Works even when the same model |
| | appears multiple times with different temperature suffixes. |
| | |
| | Returns: |
| | List of suffix strings, e.g. |
| | ['claude_haiku_4_5_20251001_t0', 'claude_haiku_4_5_20251001_t25', ...] |
| | """ |
| | import re |
| | suffixes = [] |
| | for col in result_df.columns: |
| | m = re.match(r'^category_1_(.+)$', col) |
| | if m: |
| | suffix = m.group(1) |
| | if suffix not in ('consensus', 'agreement'): |
| | suffixes.append(suffix) |
| | return suffixes |
| |
|
| |
|
| | def create_classification_heatmap(result_df, categories, classify_mode="Single Model", models_list=None): |
| | """Create a binary heatmap showing classification for each row. |
| | |
| | Args: |
| | result_df: DataFrame with classification results |
| | categories: List of category names |
| | classify_mode: "Single Model", "Model Comparison", or "Ensemble" |
| | models_list: List of model names (for multi-model modes) |
| | """ |
| | import numpy as np |
| |
|
| | total_rows = len(result_df) |
| | if total_rows == 0: |
| | fig, ax = plt.subplots(figsize=(10, 4)) |
| | ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14) |
| | ax.axis('off') |
| | return fig |
| |
|
| | |
| | if classify_mode == "Ensemble": |
| | |
| | col_names = [f"category_{i}_consensus" for i in range(1, len(categories) + 1)] |
| | elif classify_mode == "Model Comparison" and models_list: |
| | |
| | suffix = _find_model_column_suffix(result_df, models_list[0]) |
| | col_names = [f"category_{i}_{suffix}" for i in range(1, len(categories) + 1)] |
| | else: |
| | |
| | col_names = [f"category_{i}" for i in range(1, len(categories) + 1)] |
| |
|
| | |
| | matrix_data = [] |
| | for col in col_names: |
| | if col in result_df.columns: |
| | matrix_data.append(result_df[col].astype(int).values) |
| | else: |
| | matrix_data.append(np.zeros(total_rows, dtype=int)) |
| |
|
| | matrix = np.array(matrix_data).T |
| |
|
| | |
| | fig_height = max(4, min(20, total_rows * 0.15)) |
| | fig_width = max(8, len(categories) * 0.8) |
| | fig, ax = plt.subplots(figsize=(fig_width, fig_height)) |
| |
|
| | |
| | from matplotlib.colors import ListedColormap |
| | cmap = ListedColormap(['#1a1a1a', '#E8A33C']) |
| |
|
| | |
| | im = ax.imshow(matrix, aspect='auto', cmap=cmap, vmin=0, vmax=1) |
| |
|
| | |
| | ax.set_xticks(range(len(categories))) |
| | ax.set_xticklabels(categories, rotation=45, ha='right', fontsize=9) |
| | ax.set_xlabel('Categories', fontsize=11) |
| | ax.set_ylabel(f'Responses (n={total_rows})', fontsize=11) |
| | ax.set_yticks([]) |
| |
|
| | title = 'Classification Matrix' |
| | if classify_mode == "Ensemble": |
| | title += ' (Ensemble Consensus)' |
| | elif classify_mode == "Model Comparison": |
| | title += f' ({models_list[0].split("/")[-1].split(":")[0][:20]})' |
| | ax.set_title(title, fontsize=14, fontweight='bold') |
| |
|
| | |
| | from matplotlib.patches import Patch |
| | legend_elements = [ |
| | Patch(facecolor='#1a1a1a', edgecolor='white', label='Not Present'), |
| | Patch(facecolor='#E8A33C', edgecolor='white', label='Present') |
| | ] |
| | ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1)) |
| |
|
| | plt.tight_layout() |
| | return fig |
| |
|
| |
|
| | def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None): |
| | """Create a bar chart showing category distribution. |
| | |
| | Args: |
| | result_df: DataFrame with classification results |
| | categories: List of category names |
| | classify_mode: "Single Model", "Model Comparison", or "Ensemble" |
| | models_list: List of model names (for multi-model modes) |
| | """ |
| | import numpy as np |
| |
|
| | total_rows = len(result_df) |
| | if total_rows == 0: |
| | fig, ax = plt.subplots(figsize=(10, 4)) |
| | ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14) |
| | ax.axis('off') |
| | return fig |
| |
|
| | |
| | model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04', '#9333ea', '#0891b2', '#be185d', '#65a30d'] |
| |
|
| | if classify_mode == "Single Model": |
| | |
| | fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8))) |
| |
|
| | dist_data = [] |
| | for i, cat in enumerate(categories, 1): |
| | col_name = f"category_{i}" |
| | if col_name in result_df.columns: |
| | count = int(result_df[col_name].sum()) |
| | pct = (count / total_rows) * 100 |
| | dist_data.append({"Category": cat, "Percentage": round(pct, 1)}) |
| |
|
| | categories_list = [d["Category"] for d in dist_data][::-1] |
| | percentages = [d["Percentage"] for d in dist_data][::-1] |
| |
|
| | bars = ax.barh(categories_list, percentages, color='#2563eb') |
| | ax.set_xlim(0, 100) |
| | ax.set_xlabel('Percentage (%)', fontsize=11) |
| | ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold') |
| |
|
| | for bar, pct in zip(bars, percentages): |
| | ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2, |
| | f'{pct:.1f}%', va='center', fontsize=10) |
| |
|
| | elif classify_mode == "Ensemble": |
| | |
| | fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8))) |
| |
|
| | dist_data = [] |
| | for i, cat in enumerate(categories, 1): |
| | col_name = f"category_{i}_consensus" |
| | if col_name in result_df.columns: |
| | count = int(result_df[col_name].sum()) |
| | pct = (count / total_rows) * 100 |
| | dist_data.append({"Category": cat, "Percentage": round(pct, 1)}) |
| |
|
| | categories_list = [d["Category"] for d in dist_data][::-1] |
| | percentages = [d["Percentage"] for d in dist_data][::-1] |
| |
|
| | bars = ax.barh(categories_list, percentages, color='#16a34a') |
| | ax.set_xlim(0, 100) |
| | ax.set_xlabel('Percentage (%)', fontsize=11) |
| | ax.set_title('Ensemble Consensus Distribution (%)', fontsize=14, fontweight='bold') |
| |
|
| | for bar, pct in zip(bars, percentages): |
| | ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2, |
| | f'{pct:.1f}%', va='center', fontsize=10) |
| |
|
| | else: |
| | |
| | if not models_list: |
| | models_list = [] |
| |
|
| | |
| | model_suffixes = [_find_model_column_suffix(result_df, m) for m in models_list] |
| | n_models = len(model_suffixes) |
| | n_categories = len(categories) |
| |
|
| | fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2))) |
| |
|
| | |
| | bar_height = 0.8 / n_models |
| | y_positions = np.arange(n_categories) |
| |
|
| | for model_idx, (model_name, suffix) in enumerate(zip(models_list, model_suffixes)): |
| | model_pcts = [] |
| | for i in range(1, n_categories + 1): |
| | col_name = f"category_{i}_{suffix}" |
| | if col_name in result_df.columns: |
| | count = int(result_df[col_name].sum()) |
| | pct = (count / total_rows) * 100 |
| | else: |
| | pct = 0 |
| | model_pcts.append(pct) |
| |
|
| | |
| | model_pcts = model_pcts[::-1] |
| | offset = (model_idx - n_models / 2 + 0.5) * bar_height |
| | color = model_colors[model_idx % len(model_colors)] |
| |
|
| | |
| | display_name = model_name.split('/')[-1].split(':')[0][:20] |
| | bars = ax.barh(y_positions + offset, model_pcts, bar_height * 0.9, |
| | label=display_name, color=color, alpha=0.85) |
| |
|
| | ax.set_yticks(y_positions) |
| | ax.set_yticklabels(categories[::-1]) |
| | ax.set_xlim(0, 100) |
| | ax.set_xlabel('Percentage (%)', fontsize=11) |
| | ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold') |
| | ax.legend(loc='lower right', fontsize=9) |
| |
|
| | plt.tight_layout() |
| | return fig |
| |
|
| |
|
| | |
| | st.set_page_config( |
| | page_title="CatVader - Social Media Classifier", |
| | page_icon="🐱", |
| | layout="wide" |
| | ) |
| |
|
| | |
| | st.markdown(""" |
| | <style> |
| | /* Import Garamond font and apply globally */ |
| | @import url('https://fonts.googleapis.com/css2?family=EB+Garamond:wght@400;500;600;700&display=swap'); |
| | |
| | *:not([class*="icon"]):not([data-testid="stIconMaterial"]):not(svg):not(path) { |
| | font-family: 'EB Garamond', Garamond, Georgia, serif !important; |
| | font-size: 17px !important; |
| | } |
| | |
| | /* Preserve Streamlit icon fonts */ |
| | [data-testid="stIconMaterial"], .material-icons, .material-symbols-rounded { |
| | font-family: 'Material Symbols Rounded', 'Material Icons' !important; |
| | font-size: 24px !important; |
| | } |
| | |
| | /* Main container styling */ |
| | .main .block-container { |
| | padding-top: 2rem; |
| | padding-bottom: 2rem; |
| | } |
| | |
| | /* Headers with gradient accent */ |
| | h1 { |
| | background: linear-gradient(90deg, #E8A33C 0%, #D4872C 100%); |
| | -webkit-background-clip: text; |
| | -webkit-text-fill-color: transparent; |
| | background-clip: text; |
| | font-weight: 700; |
| | } |
| | |
| | /* Card-like sections */ |
| | .stExpander { |
| | border: 1px solid #E8D5B5; |
| | border-radius: 12px; |
| | box-shadow: 0 2px 8px rgba(232, 163, 60, 0.08); |
| | } |
| | |
| | /* File uploader styling */ |
| | .stFileUploader { |
| | border-radius: 12px; |
| | } |
| | |
| | .stFileUploader > div > div { |
| | border: 2px dashed #E8A33C; |
| | border-radius: 12px; |
| | background: linear-gradient(135deg, #FEFCF9 0%, #F5EFE6 100%); |
| | } |
| | |
| | /* Button styling */ |
| | .stButton > button { |
| | border-radius: 8px; |
| | font-weight: 600; |
| | transition: all 0.2s ease; |
| | border: 2px solid #E8A33C; |
| | background: #FEFCF9; |
| | color: #D4872C; |
| | } |
| | |
| | /* Tall button for example dataset (matches file uploader height) */ |
| | .tall-button .stButton > button { |
| | min-height: 107px; |
| | border-radius: 12px; |
| | } |
| | |
| | .stButton > button:hover { |
| | transform: translateY(-1px); |
| | box-shadow: 0 4px 12px rgba(232, 163, 60, 0.3); |
| | background: #F5EFE6; |
| | } |
| | |
| | /* Primary button */ |
| | .stButton > button[kind="primary"] { |
| | background: linear-gradient(135deg, #E8A33C 0%, #D4872C 100%); |
| | border: none; |
| | color: white; |
| | } |
| | |
| | /* Success/info messages */ |
| | .stSuccess { |
| | background-color: #E8F5E9; |
| | border-left: 4px solid #4CAF50; |
| | border-radius: 0 8px 8px 0; |
| | } |
| | |
| | .stInfo { |
| | background-color: #FFF8E8; |
| | border-left: 4px solid #E8A33C; |
| | border-radius: 0 8px 8px 0; |
| | } |
| | |
| | /* Radio buttons */ |
| | .stRadio > div { |
| | gap: 0.5rem; |
| | display: flex; |
| | width: 100%; |
| | } |
| | |
| | .stRadio > div > label { |
| | background: #F5EFE6; |
| | padding: 0.5rem 1rem; |
| | border-radius: 20px; |
| | border: 1px solid transparent; |
| | transition: all 0.2s ease; |
| | flex: 1; |
| | text-align: center; |
| | justify-content: center; |
| | } |
| | |
| | .stRadio > div > label:hover { |
| | border-color: #E8A33C; |
| | } |
| | |
| | /* Text inputs */ |
| | .stTextInput > div > div > input { |
| | border-radius: 8px; |
| | border: 1px solid #E8D5B5; |
| | } |
| | |
| | .stTextInput > div > div > input:focus { |
| | border-color: #E8A33C; |
| | box-shadow: 0 0 0 2px rgba(232, 163, 60, 0.2); |
| | } |
| | |
| | /* Select boxes */ |
| | .stSelectbox > div > div { |
| | border-radius: 8px; |
| | } |
| | |
| | /* Dataframe styling */ |
| | .stDataFrame { |
| | border-radius: 12px; |
| | overflow: hidden; |
| | box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05); |
| | } |
| | |
| | /* Progress bar */ |
| | .stProgress > div > div { |
| | background: linear-gradient(90deg, #E8A33C 0%, #D4872C 100%); |
| | border-radius: 10px; |
| | } |
| | |
| | /* Slider */ |
| | .stSlider > div > div > div { |
| | background: #E8A33C; |
| | } |
| | |
| | /* Divider */ |
| | hr { |
| | border: none; |
| | height: 1px; |
| | background: linear-gradient(90deg, transparent, #E8D5B5, transparent); |
| | margin: 1.5rem 0; |
| | } |
| | |
| | /* Code blocks */ |
| | .stCodeBlock { |
| | border-radius: 12px; |
| | border: 1px solid #E8D5B5; |
| | } |
| | |
| | /* Metric cards */ |
| | .stMetric { |
| | background: linear-gradient(135deg, #FEFCF9 0%, #F5EFE6 100%); |
| | padding: 1rem; |
| | border-radius: 12px; |
| | border: 1px solid #E8D5B5; |
| | } |
| | |
| | /* Download buttons */ |
| | .stDownloadButton > button { |
| | background: #F5EFE6; |
| | border: 1px solid #E8A33C; |
| | color: #D4872C; |
| | } |
| | |
| | .stDownloadButton > button:hover { |
| | background: #E8A33C; |
| | color: white; |
| | } |
| | |
| | /* Multiselect */ |
| | .stMultiSelect > div > div { |
| | border-radius: 8px; |
| | } |
| | |
| | /* Status indicator */ |
| | .stStatus { |
| | border-radius: 12px; |
| | } |
| | |
| | /* Column gaps */ |
| | [data-testid="column"] { |
| | padding: 0 0.5rem; |
| | } |
| | |
| | /* Logo and title alignment */ |
| | [data-testid="column"]:first-child img { |
| | border-radius: 8px; |
| | } |
| | </style> |
| | """, unsafe_allow_html=True) |
| |
|
| | |
| | if 'categories' not in st.session_state: |
| | st.session_state.categories = [''] * MAX_CATEGORIES |
| | if 'category_count' not in st.session_state: |
| | st.session_state.category_count = INITIAL_CATEGORIES |
| | if 'task_mode' not in st.session_state: |
| | st.session_state.task_mode = None |
| | if 'extracted_categories' not in st.session_state: |
| | st.session_state.extracted_categories = None |
| | if 'results' not in st.session_state: |
| | st.session_state.results = None |
| | if 'active_tab' not in st.session_state: |
| | st.session_state.active_tab = "survey" |
| | if 'survey_data' not in st.session_state: |
| | st.session_state.survey_data = None |
| | if 'pdf_data' not in st.session_state: |
| | st.session_state.pdf_data = None |
| | if 'image_data' not in st.session_state: |
| | st.session_state.image_data = None |
| | if 'extraction_params' not in st.session_state: |
| | st.session_state.extraction_params = None |
| | if 'bluesky_df' not in st.session_state: |
| | st.session_state.bluesky_df = None |
| |
|
| | |
| | st.markdown(""" |
| | <div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;"> |
| | <img src="https://huggingface.co/spaces/CatVader/social-media-classifier/resolve/main/logo.png" width="100" style="border-radius: 8px;"> |
| | <div> |
| | <div style="font-size: 2.2rem; font-weight: 700; color: #333; font-family: 'EB Garamond', Garamond, Georgia, serif; line-height: 1.1;">CatVader</div> |
| | <div style="font-size: 1.1rem; font-weight: 500; color: #E8A33C; font-family: 'EB Garamond', Garamond, Georgia, serif; margin-bottom: 4px;">NLP for Survey Research</div> |
| | <div style="font-size: 1rem; font-weight: 400; color: #666; font-family: 'EB Garamond', Garamond, Georgia, serif;">Research-grade classification of social media posts, PDFs, and images using AI models.</div> |
| | <div style="font-size: 0.85rem; font-weight: 400; color: #888; font-family: 'EB Garamond', Garamond, Georgia, serif; margin-top: 4px;">Developed at UC Berkeley</div> |
| | </div> |
| | </div> |
| | """, unsafe_allow_html=True) |
| |
|
| | |
| | with st.expander("About This App"): |
| | st.markdown(""" |
| | **Privacy Notice:** Your data is sent to third-party LLM APIs for classification. Do not upload sensitive, confidential, or personally identifiable information (PII). |
| | |
| | --- |
| | |
| | **CatVader** is an open-source Python package for classifying and exploring social media data using Large Language Models. |
| | |
| | ### What It Does |
| | - **Extract Categories**: Discover themes and categories in your data automatically |
| | - **Assign Categories**: Classify data into your predefined categories |
| | - **Extract & Assign**: Let CatVader discover categories, then classify all your data |
| | |
| | ### Supported Providers |
| | OpenAI (GPT-4o, GPT-4o Mini), Anthropic (Claude), Google (Gemini), Mistral, HuggingFace, xAI (Grok), and Perplexity. Use the free tier or bring your own API key. |
| | |
| | ### Beta Test - We Want Your Feedback! |
| | This app is currently in **beta** and **free to use** while CatVader is under active development, made possible by **Bashir Ahmed's generous fellowship support**. |
| | |
| | - Found a bug? Have a feature request? Please open an issue on [GitHub](https://github.com/chrissoria/cat-vader) |
| | - Reach out directly: [chrissoria@berkeley.edu](mailto:chrissoria@berkeley.edu) |
| | |
| | ### Acknowledgments |
| | - **Bashir Ahmed** for his generous fellowship support that makes this free beta possible |
| | - **Claude Fischer** for his thoughtful feedback and collaboration on research that helped inspire this project |
| | - **Kevin Collins** from Survey360 for his input |
| | - **Fendi Tsim** for sharing it widely |
| | |
| | ### Links |
| | - **Website**: [christophersoria.com](https://christophersoria.com) |
| | - **PyPI**: [pip install cat-vader](https://pypi.org/project/cat-vader/) |
| | - **GitHub**: [github.com/chrissoria/cat-vader](https://github.com/chrissoria/cat-vader) |
| | |
| | ### Citation |
| | If you use CatVader in your research, please cite: |
| | ``` |
| | Soria, C. (2025). CatVader: A Python package for LLM-based social media classification. DOI: 10.5281/zenodo.15532316 |
| | ``` |
| | """) |
| |
|
| | |
| | col_input, col_output = st.columns([1, 1]) |
| |
|
| | with col_input: |
| | |
| | input_type_choice = st.radio( |
| | "Input Type", |
| | options=["Social Media Posts", "PDF Documents", "Images"], |
| | horizontal=True, |
| | key="input_type_radio" |
| | ) |
| |
|
| | |
| | input_data = None |
| | input_type_selected = "text" |
| | description = "" |
| | original_filename = "data" |
| | pdf_mode = "Image (visual documents)" |
| |
|
| | if input_type_choice == "Social Media Posts": |
| | input_type_selected = "text" |
| |
|
| | data_source = st.radio( |
| | "Data Source", |
| | options=["Upload CSV/Excel", "Fetch from Bluesky"], |
| | horizontal=True, |
| | key="data_source_radio" |
| | ) |
| |
|
| | if data_source == "Upload CSV/Excel": |
| | st.session_state.bluesky_df = None |
| | upload_col, example_col = st.columns([3, 1]) |
| | with upload_col: |
| | uploaded_file = st.file_uploader( |
| | "Upload Data (CSV or Excel)", |
| | type=['csv', 'xlsx', 'xls'], |
| | key="survey_file" |
| | ) |
| | with example_col: |
| | st.markdown("<div style='height: 27px;'></div>", unsafe_allow_html=True) |
| | st.markdown('<div class="tall-button">', unsafe_allow_html=True) |
| | if st.button("Try Example Dataset", key="example_btn", use_container_width=True): |
| | st.session_state.example_loaded = True |
| | st.markdown('</div>', unsafe_allow_html=True) |
| |
|
| | columns = [] |
| | df = None |
| | if uploaded_file is not None: |
| | try: |
| | if uploaded_file.name.endswith('.csv'): |
| | df = pd.read_csv(uploaded_file) |
| | else: |
| | df = pd.read_excel(uploaded_file) |
| | columns = df.columns.tolist() |
| | st.success(f"Loaded {len(df):,} rows") |
| | except Exception as e: |
| | st.error(f"Error loading file: {e}") |
| | elif hasattr(st.session_state, 'example_loaded') and st.session_state.example_loaded: |
| | try: |
| | df = pd.read_csv("example_data.csv") |
| | columns = df.columns.tolist() |
| | st.success(f"Loaded example dataset ({len(df)} rows)") |
| | except: |
| | pass |
| |
|
| | selected_column = st.selectbox( |
| | "Column to Process", |
| | options=columns if columns else ["Upload a file first"], |
| | disabled=not columns, |
| | key="survey_column" |
| | ) |
| |
|
| | description = selected_column if columns else "" |
| | original_filename = uploaded_file.name if uploaded_file else "example_data.csv" |
| |
|
| | if df is not None and columns and selected_column in columns: |
| | input_data = df[selected_column].tolist() |
| |
|
| | else: |
| | bsky_handle = st.text_input( |
| | "Bluesky Handle", |
| | placeholder="e.g. aoc.bsky.social or @aoc.bsky.social", |
| | key="bluesky_handle_input" |
| | ) |
| | bsky_num_posts = st.slider( |
| | "Number of Posts to Fetch", |
| | min_value=10, max_value=250, value=50, step=10, |
| | key="bluesky_num_posts" |
| | ) |
| | if st.button("Fetch Posts", key="fetch_bluesky_btn"): |
| | handle_clean = bsky_handle.strip().lstrip("@") |
| | if not handle_clean: |
| | st.error("Please enter a Bluesky handle.") |
| | else: |
| | with st.spinner(f"Fetching {bsky_num_posts} posts from {handle_clean}..."): |
| | try: |
| | from catvader._social_media import fetch_bluesky |
| | df_bsky = fetch_bluesky(limit=bsky_num_posts, handle=handle_clean) |
| | df_bsky = df_bsky[df_bsky["media_type"] != "REPOST_FACADE"].reset_index(drop=True) |
| | st.session_state.bluesky_df = df_bsky |
| | except Exception as e: |
| | st.error(f"Error fetching posts: {e}") |
| |
|
| | if st.session_state.bluesky_df is not None: |
| | bsky_df = st.session_state.bluesky_df |
| | st.success(f"Fetched {len(bsky_df)} posts") |
| | st.dataframe( |
| | bsky_df[["timestamp", "text", "likes", "replies"]].head(5), |
| | use_container_width=True |
| | ) |
| | handle_clean = bsky_handle.strip().lstrip("@") if bsky_handle else "bluesky" |
| | input_data = bsky_df["text"].tolist() |
| | description = f"Bluesky posts from @{handle_clean}" |
| | original_filename = f"bluesky_{handle_clean.replace('.', '_')}" |
| |
|
| | elif input_type_choice == "PDF Documents": |
| | input_type_selected = "pdf" |
| |
|
| | pdf_files = st.file_uploader( |
| | "Upload PDF Document(s)", |
| | type=['pdf'], |
| | accept_multiple_files=True, |
| | key="pdf_files" |
| | ) |
| |
|
| | pdf_description = st.text_input( |
| | "Document Description", |
| | placeholder="e.g., 'research papers', 'interview transcripts'", |
| | help="Helps the LLM understand context", |
| | key="pdf_desc" |
| | ) |
| |
|
| | pdf_mode = st.radio( |
| | "Processing Mode", |
| | options=["Image (visual documents)", "Text (text-heavy)", "Both (comprehensive)"], |
| | key="pdf_mode" |
| | ) |
| |
|
| | if pdf_files: |
| | input_data = [] |
| | pdf_name_map = {} |
| | for f in pdf_files: |
| | with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp: |
| | tmp.write(f.read()) |
| | input_data.append(tmp.name) |
| | pdf_name_map[tmp.name] = f.name.replace('.pdf', '') |
| | st.session_state.pdf_name_map = pdf_name_map |
| | description = pdf_description or "document" |
| | original_filename = "pdf_files" |
| | st.success(f"Uploaded {len(pdf_files)} PDF file(s)") |
| |
|
| | else: |
| | input_type_selected = "image" |
| |
|
| | image_files = st.file_uploader( |
| | "Upload Images", |
| | type=['png', 'jpg', 'jpeg', 'gif', 'webp'], |
| | accept_multiple_files=True, |
| | key="image_files" |
| | ) |
| |
|
| | image_description = st.text_input( |
| | "Image Description", |
| | placeholder="e.g., 'product photos', 'social media posts'", |
| | help="Helps the LLM understand context", |
| | key="image_desc" |
| | ) |
| |
|
| | if image_files: |
| | input_data = [] |
| | for f in image_files: |
| | suffix = '.' + f.name.split('.')[-1] |
| | with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: |
| | tmp.write(f.read()) |
| | input_data.append(tmp.name) |
| | description = image_description or "images" |
| | original_filename = "image_files" |
| | st.success(f"Uploaded {len(image_files)} image file(s)") |
| |
|
| | st.markdown("---") |
| |
|
| | |
| | st.markdown("### What would you like to do?") |
| | col_btn1, col_btn2 = st.columns(2) |
| | with col_btn1: |
| | manual_mode = st.button("Enter Categories Manually", use_container_width=True) |
| | with col_btn2: |
| | auto_mode = st.button("Auto-extract Categories", use_container_width=True) |
| |
|
| | if manual_mode: |
| | st.session_state.task_mode = "manual" |
| | if auto_mode: |
| | st.session_state.task_mode = "auto_extract" |
| |
|
| | |
| | if st.session_state.task_mode == "auto_extract": |
| | st.markdown("### Auto-extract Categories") |
| | st.markdown("We'll analyze your data to discover the main categories.") |
| |
|
| | max_categories = st.slider( |
| | "Number of Categories to Extract", |
| | min_value=3, |
| | max_value=25, |
| | value=12, |
| | help="How many categories should be identified in your data" |
| | ) |
| |
|
| | specificity = st.selectbox( |
| | "How specific should categories be?", |
| | options=["Broad", "Moderate", "Narrow"], |
| | index=0, |
| | help="Broad = general themes, Moderate = balanced detail, Narrow = highly specific categories" |
| | ) |
| |
|
| | focus = st.text_input( |
| | "What should categories be focused around? (optional)", |
| | placeholder="e.g., 'decisions to move', 'emotional responses', 'financial factors'", |
| | help="Guide the model to prioritize extracting categories related to this focus" |
| | ) |
| |
|
| | |
| | st.markdown("### Model Selection") |
| | model_tier = st.radio( |
| | "Model Tier", |
| | options=["Free Models", "Bring Your Own Key"], |
| | key="extract_model_tier" |
| | ) |
| |
|
| | if model_tier == "Free Models": |
| | model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="extract_model") |
| | model = FREE_MODELS_MAP[model_display] |
| | api_key = "" |
| | else: |
| | model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="extract_model_paid") |
| | api_key = st.text_input("API Key", type="password", key="extract_api_key") |
| |
|
| | if st.button("Extract Categories", type="primary"): |
| | if input_data is None: |
| | st.error("Please upload data first") |
| | else: |
| | mode = None |
| | if input_type_selected == "pdf": |
| | mode_mapping = { |
| | "Image (visual documents)": "image", |
| | "Text (text-heavy)": "text", |
| | "Both (comprehensive)": "both" |
| | } |
| | mode = mode_mapping.get(pdf_mode, "image") |
| |
|
| | actual_api_key, provider = get_api_key(model, model_tier, api_key) |
| | if not actual_api_key: |
| | st.error(f"{provider} API key not configured") |
| | else: |
| | model_source = get_model_source(model) |
| |
|
| | |
| | num_items = len(input_data) if isinstance(input_data, list) else 1 |
| | if input_type_selected == "pdf": |
| | |
| | total_pages = sum(count_pdf_pages(p) for p in (input_data if isinstance(input_data, list) else [input_data])) |
| | est_seconds = total_pages * 5 |
| | elif input_type_selected == "image": |
| | |
| | est_seconds = num_items * 4 |
| | else: |
| | |
| | est_seconds = max(10, num_items * 0.5) |
| |
|
| | |
| | progress_bar = st.progress(0) |
| | status_text = st.empty() |
| | start_time = time.time() |
| |
|
| | |
| | def extract_progress_callback(current_step, total_steps, step_label): |
| | progress = current_step / total_steps if total_steps > 0 else 0 |
| | progress_bar.progress(min(progress, 1.0)) |
| |
|
| | elapsed = time.time() - start_time |
| | if current_step > 0: |
| | avg_time = elapsed / current_step |
| | eta_seconds = avg_time * (total_steps - current_step) |
| | eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m" |
| | else: |
| | eta_str = "" |
| |
|
| | status_text.text(f"Extracting categories: {step_label} ({progress*100:.0f}%){eta_str}") |
| |
|
| | extract_kwargs = { |
| | 'input_data': input_data, |
| | 'api_key': actual_api_key, |
| | 'input_type': input_type_selected, |
| | 'description': description, |
| | 'user_model': model, |
| | 'model_source': model_source, |
| | 'max_categories': int(max_categories), |
| | 'specificity': specificity.lower(), |
| | 'progress_callback': extract_progress_callback, |
| | } |
| | if mode: |
| | extract_kwargs['mode'] = mode |
| | if focus and focus.strip(): |
| | extract_kwargs['focus'] = focus.strip() |
| |
|
| | try: |
| | extract_result = catvader.extract(**extract_kwargs) |
| | categories = extract_result.get('top_categories', []) |
| |
|
| | processing_time = time.time() - start_time |
| | progress_bar.progress(1.0) |
| | status_text.text(f"Completed in {processing_time:.1f}s") |
| |
|
| | if categories: |
| | st.success(f"Extracted {len(categories)} categories in {processing_time:.1f}s") |
| | st.session_state.extracted_categories = categories |
| | |
| | st.session_state.extraction_params = { |
| | 'model': model, |
| | 'model_source': model_source, |
| | 'max_categories': int(max_categories), |
| | 'input_type': input_type_selected, |
| | 'description': description, |
| | 'mode': mode, |
| | } |
| | st.session_state.task_mode = "manual" |
| | st.rerun() |
| | else: |
| | st.error("No categories were extracted from the data") |
| | except Exception as e: |
| | st.error(f"Error: {str(e)}") |
| |
|
| | |
| | if st.session_state.task_mode == "manual": |
| | st.markdown("### Categories") |
| | st.markdown("Enter your classification categories below.") |
| |
|
| | |
| | if st.session_state.extracted_categories: |
| | for i, cat in enumerate(st.session_state.extracted_categories[:MAX_CATEGORIES]): |
| | st.session_state.categories[i] = cat |
| | st.session_state.category_count = min(len(st.session_state.extracted_categories), MAX_CATEGORIES) |
| | st.session_state.extracted_categories = None |
| |
|
| | placeholder_examples = [ |
| | "e.g., Positive sentiment", |
| | "e.g., Negative sentiment", |
| | "e.g., Product feedback", |
| | "e.g., Service complaint", |
| | "e.g., Feature request", |
| | "e.g., Custom category" |
| | ] |
| |
|
| | categories_entered = [] |
| | for i in range(st.session_state.category_count): |
| | placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category" |
| | cat_value = st.text_input( |
| | f"Category {i+1}", |
| | value=st.session_state.categories[i], |
| | placeholder=placeholder, |
| | key=f"cat_{i}" |
| | ) |
| | st.session_state.categories[i] = cat_value |
| | if cat_value.strip(): |
| | categories_entered.append(cat_value.strip()) |
| |
|
| | if st.session_state.category_count < MAX_CATEGORIES: |
| | if st.button("+ Add More"): |
| | st.session_state.category_count += 1 |
| | st.rerun() |
| |
|
| | st.markdown("### Model Selection") |
| |
|
| | |
| | classify_mode = st.radio( |
| | "Classification Mode", |
| | options=["Single Model", "Model Comparison", "Ensemble"], |
| | horizontal=True, |
| | key="classify_mode", |
| | help="Single: one model. Comparison: see results from multiple models side-by-side. Ensemble: multiple models vote for consensus." |
| | ) |
| |
|
| | model_tier = st.radio( |
| | "Model Tier", |
| | options=["Free Models", "Bring Your Own Key"], |
| | key="classify_model_tier" |
| | ) |
| |
|
| | |
| | is_multi_model = classify_mode in ["Model Comparison", "Ensemble"] |
| | min_models = 3 if classify_mode == "Ensemble" else 2 |
| |
|
| | |
| | |
| | model_temperatures = {} |
| | |
| | ensemble_runs = [] |
| |
|
| | if classify_mode == "Ensemble": |
| | |
| | if "ensemble_num_runs" not in st.session_state: |
| | st.session_state.ensemble_num_runs = 3 |
| |
|
| | if model_tier == "Free Models": |
| | model_options = FREE_MODEL_DISPLAY_NAMES |
| | is_free = True |
| | else: |
| | model_options = PAID_MODEL_CHOICES |
| | is_free = False |
| |
|
| | st.markdown(f"**Model Runs** (select {min_models}+ runs)") |
| | for i in range(st.session_state.ensemble_num_runs): |
| | cols = st.columns([3, 1, 0.5]) |
| | with cols[0]: |
| | default_idx = 0 if i < len(model_options) else i % len(model_options) |
| | selected = st.selectbox( |
| | f"Run {i+1}", options=model_options, |
| | index=default_idx, key=f"ensemble_model_{i}", |
| | label_visibility="collapsed" |
| | ) |
| | with cols[1]: |
| | temp = st.number_input( |
| | "Temp", min_value=0.0, max_value=2.0, value=round(i * 0.25, 2), |
| | step=0.25, key=f"ensemble_temp_{i}", label_visibility="collapsed" |
| | ) |
| | with cols[2]: |
| | if st.session_state.ensemble_num_runs > 3: |
| | if st.button("✕", key=f"ensemble_remove_{i}"): |
| | st.session_state.ensemble_num_runs -= 1 |
| | st.rerun() |
| |
|
| | model_name = FREE_MODELS_MAP[selected] if is_free else selected |
| | ensemble_runs.append((model_name, temp)) |
| |
|
| | if st.button("Add Run", key="add_ensemble_run"): |
| | st.session_state.ensemble_num_runs += 1 |
| | st.rerun() |
| |
|
| | models_list = [r[0] for r in ensemble_runs] |
| | model_temperatures = {f"{r[0]}__run{i}": r[1] for i, r in enumerate(ensemble_runs)} |
| | api_key = "" if model_tier == "Free Models" else st.text_input("API Key", type="password", key="classify_api_key") |
| |
|
| | elif is_multi_model: |
| | |
| | if model_tier == "Free Models": |
| | default_models = FREE_MODEL_DISPLAY_NAMES[:min_models] if len(FREE_MODEL_DISPLAY_NAMES) >= min_models else FREE_MODEL_DISPLAY_NAMES |
| | model_displays = st.multiselect( |
| | f"Models (select {min_models}+)", |
| | options=FREE_MODEL_DISPLAY_NAMES, |
| | default=default_models, |
| | key="classify_models_multi" |
| | ) |
| | models_list = [FREE_MODELS_MAP[d] for d in model_displays] |
| | api_key = "" |
| | else: |
| | default_models = PAID_MODEL_CHOICES[:min_models] if len(PAID_MODEL_CHOICES) >= min_models else PAID_MODEL_CHOICES |
| | models_list = st.multiselect( |
| | f"Models (select {min_models}+)", |
| | options=PAID_MODEL_CHOICES, |
| | default=default_models, |
| | key="classify_models_multi_paid" |
| | ) |
| | api_key = st.text_input("API Key", type="password", key="classify_api_key") |
| |
|
| | if models_list: |
| | st.markdown("**Model Temperature**") |
| | temp_cols = st.columns(len(models_list)) |
| | for idx, (col, m) in enumerate(zip(temp_cols, models_list)): |
| | short_name = m.split('/')[-1].split(':')[0][:20] |
| | model_temperatures[m] = col.number_input( |
| | short_name, |
| | min_value=0.0, |
| | max_value=2.0, |
| | value=0.0, |
| | step=0.25, |
| | key=f"temp_{idx}", |
| | help=f"Temperature for {m} (0 = deterministic, higher = more creative)" |
| | ) |
| | else: |
| | |
| | if model_tier == "Free Models": |
| | model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model") |
| | model = FREE_MODELS_MAP[model_display] |
| | models_list = [model] |
| | api_key = "" |
| | else: |
| | model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid") |
| | models_list = [model] |
| | api_key = st.text_input("API Key", type="password", key="classify_api_key") |
| |
|
| | |
| | consensus_threshold = 0.5 |
| | if classify_mode == "Ensemble": |
| | consensus_options = { |
| | "Majority (50%+)": 0.5, |
| | "Two-Thirds (67%+)": 0.67, |
| | "Unanimous (100%)": 1.0, |
| | } |
| | consensus_choice = st.radio( |
| | "Consensus Rule", |
| | options=list(consensus_options.keys()), |
| | horizontal=True, |
| | key="consensus_choice", |
| | help="How many models must agree for a category to be marked present" |
| | ) |
| | consensus_threshold = consensus_options[consensus_choice] |
| |
|
| | if st.button("Categorize Data", type="primary", use_container_width=True): |
| | if input_data is None: |
| | st.error("Please upload data first") |
| | elif not categories_entered: |
| | st.error("Please enter at least one category") |
| | elif classify_mode == "Model Comparison" and len(models_list) < 2: |
| | st.error("Please select at least 2 models for comparison mode") |
| | elif classify_mode == "Ensemble" and len(models_list) < 3: |
| | st.error("Please select at least 3 models for ensemble mode (needed for majority voting)") |
| | else: |
| | |
| | mode = None |
| | if input_type_selected == "pdf": |
| | mode_mapping = { |
| | "Image (visual documents)": "image", |
| | "Text (text-heavy)": "text", |
| | "Both (comprehensive)": "both" |
| | } |
| | mode = mode_mapping.get(pdf_mode, "image") |
| |
|
| | |
| | |
| | models_tuples = [] |
| | api_key_error = None |
| | if ensemble_runs: |
| | |
| | for m, temp in ensemble_runs: |
| | actual_key, provider = get_api_key(m, model_tier, api_key) |
| | if not actual_key: |
| | api_key_error = f"{provider} API key not configured for {m}" |
| | break |
| | m_source = get_model_source(m) |
| | models_tuples.append((m, m_source, actual_key, {"creativity": temp})) |
| | else: |
| | for m in models_list: |
| | actual_key, provider = get_api_key(m, model_tier, api_key) |
| | if not actual_key: |
| | api_key_error = f"{provider} API key not configured for {m}" |
| | break |
| | m_source = get_model_source(m) |
| | temp = model_temperatures.get(m) |
| | if temp is not None and is_multi_model: |
| | models_tuples.append((m, m_source, actual_key, {"creativity": temp})) |
| | else: |
| | models_tuples.append((m, m_source, actual_key)) |
| |
|
| | if api_key_error: |
| | st.error(api_key_error) |
| | else: |
| | items_list = input_data if isinstance(input_data, list) else [input_data] |
| |
|
| | |
| | progress_bar = st.progress(0) |
| | status_text = st.empty() |
| | start_time = time.time() |
| |
|
| | |
| | if input_type_selected == "pdf": |
| | |
| | def pdf_progress_callback(current_idx, total_pages, page_label): |
| | progress = current_idx / total_pages if total_pages > 0 else 0 |
| | progress_bar.progress(min(progress, 1.0)) |
| |
|
| | elapsed = time.time() - start_time |
| | if current_idx > 0: |
| | avg_time = elapsed / current_idx |
| | eta_seconds = avg_time * (total_pages - current_idx) |
| | eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m" |
| | else: |
| | eta_str = "" |
| |
|
| | status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}") |
| |
|
| | try: |
| | |
| | classify_kwargs = { |
| | "input_data": items_list, |
| | "categories": categories_entered, |
| | "models": models_tuples, |
| | "description": description, |
| | "mode": mode, |
| | "progress_callback": pdf_progress_callback, |
| | } |
| | |
| | if classify_mode == "Ensemble": |
| | classify_kwargs["consensus_threshold"] = consensus_threshold |
| |
|
| | result_df = catvader.classify(**classify_kwargs) |
| |
|
| | processing_time = time.time() - start_time |
| | total_items = len(result_df) |
| | progress_bar.progress(1.0) |
| | status_text.text(f"Completed {total_items} pages in {processing_time:.1f}s") |
| |
|
| | |
| | if 'pdf_input' in result_df.columns: |
| | pdf_name_map = st.session_state.get('pdf_name_map', {}) |
| | def replace_temp_path(val): |
| | if pd.isna(val): |
| | return val |
| | val_str = str(val) |
| | for temp_path, orig_name in pdf_name_map.items(): |
| | |
| | temp_name = os.path.basename(temp_path).replace('.pdf', '') |
| | if temp_name in val_str: |
| | return val_str.replace(temp_name, orig_name) |
| | return val_str |
| | result_df['pdf_input'] = result_df['pdf_input'].apply(replace_temp_path) |
| |
|
| | all_results = [result_df] |
| |
|
| | except Exception as e: |
| | st.error(f"Error: {str(e)}") |
| | all_results = [] |
| |
|
| | else: |
| | |
| | total_items = len(items_list) |
| |
|
| | |
| | def item_progress_callback(current_idx, total, item_label): |
| | progress = current_idx / total if total > 0 else 0 |
| | progress_bar.progress(min(progress, 1.0)) |
| |
|
| | elapsed = time.time() - start_time |
| | if current_idx > 0: |
| | avg_time = elapsed / current_idx |
| | eta_seconds = avg_time * (total - current_idx) |
| | eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m" |
| | else: |
| | eta_str = "" |
| |
|
| | status_text.text(f"Processing item {current_idx+1} of {total} ({progress*100:.0f}%){eta_str}") |
| |
|
| | try: |
| | |
| | classify_kwargs = { |
| | "input_data": items_list, |
| | "categories": categories_entered, |
| | "models": models_tuples, |
| | "description": description, |
| | "progress_callback": item_progress_callback, |
| | } |
| | |
| | if classify_mode == "Ensemble": |
| | classify_kwargs["consensus_threshold"] = consensus_threshold |
| |
|
| | result_df = catvader.classify(**classify_kwargs) |
| | all_results = [result_df] |
| |
|
| | processing_time = time.time() - start_time |
| | progress_bar.progress(1.0) |
| | status_text.text(f"Completed {total_items} items in {processing_time:.1f}s") |
| |
|
| | except Exception as e: |
| | st.error(f"Error: {str(e)}") |
| | all_results = [] |
| | processing_time = time.time() - start_time |
| |
|
| | if all_results: |
| | |
| | result_df = pd.concat(all_results, ignore_index=True) |
| |
|
| | |
| | if st.session_state.get("bluesky_df") is not None: |
| | bsky_eng = st.session_state.bluesky_df.reset_index(drop=True) |
| | if len(bsky_eng) == len(result_df): |
| | for col in ["post_id", "timestamp", "likes", "replies", "reposts", |
| | "media_type", "image_url", "post_length", |
| | "contains_url", "contains_image", "is_repost"]: |
| | if col in bsky_eng.columns: |
| | result_df[col] = bsky_eng[col].values |
| |
|
| | |
| | with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f: |
| | result_df.to_csv(f.name, index=False) |
| | csv_path = f.name |
| |
|
| | |
| | if 'processing_status' in result_df.columns: |
| | success_count = (result_df['processing_status'] == 'success').sum() |
| | success_rate = (success_count / len(result_df)) * 100 |
| | else: |
| | success_rate = 100.0 |
| |
|
| | |
| | try: |
| | catvader_version = catvader.__version__ |
| | except AttributeError: |
| | catvader_version = "unknown" |
| | python_version = sys.version.split()[0] |
| |
|
| | |
| | if len(models_list) == 1: |
| | report_model = models_list[0] |
| | report_model_source = models_tuples[0][1] |
| | else: |
| | report_model = ", ".join(models_list) |
| | report_model_source = f"{classify_mode} ({len(models_list)} models)" |
| |
|
| | |
| | |
| | if st.session_state.extraction_params: |
| | classify_params = { |
| | 'model': report_model, |
| | 'description': description, |
| | 'mode': mode, |
| | 'classify_mode': classify_mode, |
| | 'models_list': models_list, |
| | 'consensus_threshold': consensus_threshold, |
| | 'model_temperatures': model_temperatures, |
| | 'ensemble_runs': ensemble_runs if ensemble_runs else None, |
| | } |
| | code = generate_full_code(st.session_state.extraction_params, classify_params) |
| | else: |
| | code = generate_classify_code( |
| | input_type_selected, description, categories_entered, |
| | report_model, report_model_source, mode, |
| | classify_mode=classify_mode, models_list=models_list, |
| | consensus_threshold=consensus_threshold, |
| | model_temperatures=model_temperatures, |
| | ensemble_runs=ensemble_runs if ensemble_runs else None, |
| | ) |
| |
|
| | |
| | pdf_path = generate_methodology_report_pdf( |
| | categories=categories_entered, |
| | model=report_model, |
| | column_name=description, |
| | num_rows=len(result_df), |
| | model_source=report_model_source, |
| | filename=original_filename, |
| | success_rate=success_rate, |
| | result_df=result_df, |
| | processing_time=processing_time, |
| | catvader_version=catvader_version, |
| | python_version=python_version, |
| | task_type="assign", |
| | input_type=input_type_selected, |
| | description=description, |
| | classify_mode=classify_mode, |
| | models_list=models_list, |
| | code=code, |
| | consensus_threshold=consensus_threshold if classify_mode == "Ensemble" else None, |
| | ) |
| |
|
| | st.session_state.results = { |
| | 'df': result_df, |
| | 'csv_path': csv_path, |
| | 'pdf_path': pdf_path, |
| | 'code': code, |
| | 'status': f"Classified {len(result_df)} items in {processing_time:.1f}s", |
| | 'categories': categories_entered, |
| | 'classify_mode': classify_mode, |
| | 'models_list': models_list, |
| | 'model_temperatures': model_temperatures, |
| | 'ensemble_runs': ensemble_runs if ensemble_runs else None, |
| | } |
| | st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s") |
| | st.rerun() |
| | else: |
| | st.error("No items were successfully classified") |
| |
|
| | with col_output: |
| | st.markdown("### Results") |
| |
|
| | if st.session_state.results: |
| | results = st.session_state.results |
| |
|
| | |
| | viz_type = st.selectbox( |
| | "Visualization", |
| | options=["Category Distribution", "Classification Matrix"], |
| | key="viz_type", |
| | help="Distribution shows category percentages. Matrix shows each response's classifications." |
| | ) |
| |
|
| | if viz_type == "Category Distribution": |
| | fig = create_distribution_chart( |
| | results['df'], |
| | results['categories'], |
| | classify_mode=results.get('classify_mode', 'Single Model'), |
| | models_list=results.get('models_list', []) |
| | ) |
| | st.pyplot(fig) |
| | st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.") |
| | else: |
| | fig = create_classification_heatmap( |
| | results['df'], |
| | results['categories'], |
| | classify_mode=results.get('classify_mode', 'Single Model'), |
| | models_list=results.get('models_list', []) |
| | ) |
| | st.pyplot(fig) |
| | st.caption("Green = category present, Black = not present. Each row is one response.") |
| |
|
| | |
| | display_df = results['df'].copy() |
| | cols_to_hide = ['model_response', 'json', 'raw_response', 'raw_json'] |
| | display_df = display_df.drop(columns=[c for c in cols_to_hide if c in display_df.columns]) |
| | st.dataframe(display_df, use_container_width=True) |
| |
|
| | |
| | col_dl1, col_dl2, col_dl3 = st.columns(3) |
| | with col_dl1: |
| | with open(results['csv_path'], 'rb') as f: |
| | st.download_button( |
| | "Download CSV", |
| | data=f, |
| | file_name="classified_results.csv", |
| | mime="text/csv" |
| | ) |
| | with col_dl2: |
| | with open(results['pdf_path'], 'rb') as f: |
| | st.download_button( |
| | "Download Report", |
| | data=f, |
| | file_name="methodology_report.pdf", |
| | mime="application/pdf" |
| | ) |
| | with col_dl3: |
| | |
| | import io |
| | from matplotlib.backends.backend_pdf import PdfPages |
| |
|
| | plot_buffer = io.BytesIO() |
| | with PdfPages(plot_buffer) as pdf: |
| | |
| | fig1 = create_distribution_chart( |
| | results['df'], |
| | results['categories'], |
| | classify_mode=results.get('classify_mode', 'Single Model'), |
| | models_list=results.get('models_list', []) |
| | ) |
| | pdf.savefig(fig1, bbox_inches='tight') |
| | plt.close(fig1) |
| |
|
| | |
| | fig2 = create_classification_heatmap( |
| | results['df'], |
| | results['categories'], |
| | classify_mode=results.get('classify_mode', 'Single Model'), |
| | models_list=results.get('models_list', []) |
| | ) |
| | pdf.savefig(fig2, bbox_inches='tight') |
| | plt.close(fig2) |
| |
|
| | plot_buffer.seek(0) |
| | st.download_button( |
| | "Download Plots", |
| | data=plot_buffer, |
| | file_name="classification_plots.pdf", |
| | mime="application/pdf" |
| | ) |
| |
|
| | |
| | with st.expander("See the Code"): |
| | st.code(results['code'], language='python') |
| | else: |
| | st.info("Upload data, select categories, and click 'Categorize Data' to see results here.") |
| |
|
| | |
| | col_reset, col_code = st.columns(2) |
| | with col_reset: |
| | if st.button("Reset", type="secondary", use_container_width=True): |
| | st.session_state.categories = [''] * MAX_CATEGORIES |
| | st.session_state.category_count = INITIAL_CATEGORIES |
| | st.session_state.task_mode = None |
| | st.session_state.extracted_categories = None |
| | st.session_state.extraction_params = None |
| | st.session_state.results = None |
| | if hasattr(st.session_state, 'example_loaded'): |
| | del st.session_state.example_loaded |
| | st.rerun() |
| |
|
| | with col_code: |
| | if st.button("See in Code", use_container_width=True): |
| | st.session_state.show_code_modal = True |
| |
|
| | |
| | if st.session_state.get('show_code_modal'): |
| | st.markdown("---") |
| | st.markdown("### Reproducibility Code") |
| | st.markdown("Use this code to reproduce the classification with the CatVader Python package:") |
| |
|
| | |
| | if st.session_state.results: |
| | code_to_show = st.session_state.results['code'] |
| | else: |
| | |
| | current_categories = [c for c in st.session_state.categories[:st.session_state.category_count] if c.strip()] |
| |
|
| | |
| | input_type_map = {"Social Media Posts": "text", "PDF Documents": "pdf", "Images": "image"} |
| | current_input_type = input_type_map.get(st.session_state.get('input_type_radio', 'Social Media Posts'), 'text') |
| | current_description = st.session_state.get('survey_column', '') or st.session_state.get('pdf_desc', '') or st.session_state.get('image_desc', '') or 'your_data' |
| |
|
| | |
| | current_classify_mode = st.session_state.get('classify_mode', 'Single Model') |
| | current_model_tier = st.session_state.get('classify_model_tier', 'Free Models') |
| |
|
| | if current_classify_mode in ["Model Comparison", "Ensemble"]: |
| | |
| | if current_model_tier == 'Free Models': |
| | model_displays = st.session_state.get('classify_models_multi', []) |
| | current_models_list = [FREE_MODELS_MAP.get(d, d) for d in model_displays] |
| | else: |
| | current_models_list = st.session_state.get('classify_models_multi_paid', []) |
| | current_model = ", ".join(current_models_list) if current_models_list else "gpt-4o-mini" |
| | current_model_source = f"{current_classify_mode} ({len(current_models_list)} models)" |
| | else: |
| | |
| | if current_model_tier == 'Free Models': |
| | model_display = st.session_state.get('classify_model', 'GPT-4o Mini') |
| | current_model = FREE_MODELS_MAP.get(model_display, 'gpt-4o-mini') |
| | else: |
| | current_model = st.session_state.get('classify_model_paid', 'gpt-4o-mini') |
| | current_models_list = [current_model] |
| | current_model_source = get_model_source(current_model) |
| |
|
| | |
| | consensus_options = {"Majority (50%+)": 0.5, "Two-Thirds (67%+)": 0.67, "Unanimous (100%)": 1.0} |
| | current_consensus = consensus_options.get(st.session_state.get('consensus_choice', 'Majority (50%+)'), 0.5) |
| |
|
| | |
| | current_mode = None |
| | if current_input_type == "pdf": |
| | mode_mapping = { |
| | "Image (visual documents)": "image", |
| | "Text (text-heavy)": "text", |
| | "Both (comprehensive)": "both" |
| | } |
| | current_mode = mode_mapping.get(st.session_state.get('pdf_mode', 'Image (visual documents)'), 'image') |
| |
|
| | if current_categories: |
| | |
| | if st.session_state.extraction_params: |
| | current_temperatures = results.get('model_temperatures', {}) |
| | classify_params = { |
| | 'model': current_model, |
| | 'description': current_description, |
| | 'mode': current_mode, |
| | 'classify_mode': current_classify_mode, |
| | 'models_list': current_models_list, |
| | 'consensus_threshold': current_consensus, |
| | 'model_temperatures': current_temperatures, |
| | 'ensemble_runs': results.get('ensemble_runs'), |
| | } |
| | code_to_show = generate_full_code(st.session_state.extraction_params, classify_params) |
| | else: |
| | current_temperatures = results.get('model_temperatures', {}) |
| | code_to_show = generate_classify_code( |
| | current_input_type, current_description, current_categories, |
| | current_model, current_model_source, current_mode, |
| | classify_mode=current_classify_mode, models_list=current_models_list, |
| | consensus_threshold=current_consensus, |
| | model_temperatures=current_temperatures, |
| | ensemble_runs=results.get('ensemble_runs'), |
| | ) |
| | else: |
| | code_to_show = '''import catvader |
| | |
| | # Define your categories |
| | categories = [ |
| | "Category 1", |
| | "Category 2", |
| | # Add more categories... |
| | ] |
| | |
| | # Classify your data |
| | result = catvader.classify( |
| | input_data=df["your_column"].tolist(), |
| | categories=categories, |
| | api_key="YOUR_API_KEY", |
| | description="your_description", |
| | user_model="gpt-4o-mini" |
| | ) |
| | |
| | # View results |
| | print(result) |
| | result.to_csv("classified_results.csv", index=False) |
| | ''' |
| |
|
| | st.code(code_to_show, language='python') |
| | if st.button("Close"): |
| | st.session_state.show_code_modal = False |
| | st.rerun() |
| |
|