""" Streamlit app - CatLLM Survey Response Classifier Migrated from Gradio for better mobile support """ import streamlit as st import pandas as pd import tempfile import os import time import sys from datetime import datetime import matplotlib.pyplot as plt # Import catllm try: import catllm CATLLM_AVAILABLE = True except ImportError as e: print(f"Warning: Could not import catllm: {e}") CATLLM_AVAILABLE = False MAX_CATEGORIES = 10 INITIAL_CATEGORIES = 3 MAX_FILE_SIZE_MB = 100 def count_pdf_pages(pdf_path): """Count the number of pages in a PDF file.""" try: import fitz # PyMuPDF doc = fitz.open(pdf_path) page_count = len(doc) doc.close() return page_count except Exception: return 1 # Default to 1 if can't read def extract_text_from_pdfs(pdf_paths): """Extract text from all pages of all PDFs, returning list of page texts.""" import fitz # PyMuPDF all_texts = [] for pdf_path in pdf_paths: try: doc = fitz.open(pdf_path) for page in doc: text = page.get_text().strip() if text: # Only add non-empty pages all_texts.append(text) doc.close() except Exception as e: print(f"Error extracting text from {pdf_path}: {e}") return all_texts def extract_pdf_pages(pdf_paths, pdf_name_map, mode="image"): """ Extract individual pages from PDFs. Returns list of (page_data, page_label) tuples. 
For image mode: page_data is path to temp image file For text mode: page_data is extracted text """ import fitz # PyMuPDF pages = [] for pdf_path in pdf_paths: orig_name = pdf_name_map.get(pdf_path, os.path.basename(pdf_path).replace('.pdf', '')) try: doc = fitz.open(pdf_path) for page_num, page in enumerate(doc, 1): page_label = f"{orig_name}_p{page_num}" if mode == "text": # Extract text text = page.get_text().strip() if text: pages.append((text, page_label, "text")) else: # Render as image (for image or both mode) pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x zoom for better quality img_path = tempfile.NamedTemporaryFile(delete=False, suffix='.png').name pix.save(img_path) if mode == "both": text = page.get_text().strip() pages.append((img_path, page_label, "image", text)) else: pages.append((img_path, page_label, "image")) doc.close() except Exception as e: print(f"Error extracting pages from {pdf_path}: {e}") return pages # Free models - display name -> actual API model name FREE_MODELS_MAP = { "GPT-4o Mini": "gpt-4o-mini", "Gemini 2.5 Flash": "gemini-2.5-flash", "Claude 3 Haiku": "claude-3-haiku-20240307", "Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct:groq", "Qwen 2.5": "Qwen/Qwen2.5-72B-Instruct", "DeepSeek R1": "deepseek-ai/DeepSeek-R1:novita", "Mistral Medium": "mistral-medium-2505", "Grok 4 Fast": "grok-4-fast-non-reasoning", } FREE_MODEL_DISPLAY_NAMES = list(FREE_MODELS_MAP.keys()) FREE_MODEL_CHOICES = list(FREE_MODELS_MAP.values()) # Keep for backward compat # Paid models (user provides their own API key) PAID_MODEL_CHOICES = [ "gemini-2.5-flash", "gemini-2.5-pro", "gpt-4.1", "gpt-4o", "gpt-4o-mini", "claude-sonnet-4-5-20250929", "claude-opus-4-20250514", "claude-3-5-haiku-20241022", "mistral-large-latest", ] # Models routed through HuggingFace HF_ROUTED_MODELS = [ "meta-llama/Llama-3.3-70B-Instruct:groq", "deepseek-ai/DeepSeek-R1:novita", ] def is_free_model(model, model_tier): """Check if using free tier (Space pays for API).""" return 
def get_model_source(model):
    """Auto-detect the provider identifier for a model name.

    Matching is case-insensitive. An explicit HuggingFace router suffix
    (``:novita`` / ``:groq``) takes precedence over vendor keywords, so ids
    such as ``openai/gpt-oss-120b:groq`` resolve to ``"huggingface"``.
    Unrecognised names also default to ``"huggingface"``.
    """
    model_lower = model.lower()

    # Router-qualified ids are always served through HuggingFace.
    if ":novita" in model_lower or ":groq" in model_lower:
        return "huggingface"

    vendor_keywords = (
        ("gpt", "openai"),
        ("claude", "anthropic"),
        ("gemini", "google"),
        ("mistral", "mistral"),
        # Checked before the open-model keywords so Perplexity ids such as
        # "llama-3.1-sonar-..." agree with the key lookup in get_api_key.
        ("sonar", "perplexity"),
        ("qwen", "huggingface"),
        ("llama", "huggingface"),
        ("deepseek", "huggingface"),
        ("grok", "xai"),
    )
    for keyword, source in vendor_keywords:
        if keyword in model_lower:
            return source
    return "huggingface"


def get_api_key(model, model_tier, api_key_input):
    """Get the appropriate API key based on model and tier.

    Returns:
        ``(api_key, provider_label)``. On the free tier the key is read from
        the provider's environment variable (empty string when unset); on
        any other tier it is the stripped user input with label ``"User"``.
    """
    if not is_free_model(model, model_tier):
        user_key = api_key_input.strip() if api_key_input else ""
        return user_key, "User"

    hf_credentials = ("HF_API_KEY", "HuggingFace")
    if model in HF_ROUTED_MODELS:
        env_var, provider = hf_credentials
    else:
        model_lower = model.lower()
        vendor_table = (
            ("gpt", "OPENAI_API_KEY", "OpenAI"),
            ("gemini", "GOOGLE_API_KEY", "Google"),
            ("mistral", "MISTRAL_API_KEY", "Mistral"),
            ("claude", "ANTHROPIC_API_KEY", "Anthropic"),
            ("sonar", "PERPLEXITY_API_KEY", "Perplexity"),
            ("grok", "XAI_API_KEY", "xAI"),
        )
        env_var, provider = next(
            ((env, label) for keyword, env, label in vendor_table
             if keyword in model_lower),
            hf_credentials,
        )
    return os.environ.get(env_var, ""), provider


def calculate_total_file_size(files):
    """Calculate total size of uploaded files in MB.

    Accepts ``None``, a single file object, or a list/tuple of them. Objects
    exposing ``size`` use it directly; otherwise ``name`` is treated as a
    filesystem path. Entries whose size cannot be determined count as 0.
    """
    if files is None:
        return 0
    if not isinstance(files, (list, tuple)):
        files = [files]

    total_bytes = 0
    for f in files:
        try:
            if hasattr(f, 'size'):
                total_bytes += f.size
            elif hasattr(f, 'name'):
                total_bytes += os.path.getsize(f.name)
        except (OSError, AttributeError, TypeError):
            # Best effort: unreadable paths or a non-numeric size are skipped.
            pass
    return total_bytes / (1024 * 1024)


def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None):
    """Generate Python code for category extraction.

    Args:
        input_type: ``"text"``, ``"pdf"``; anything else is treated as image.
        description: Data description (also used as the column name for text).
        model: Model name placed in ``user_model``.
        model_source: Provider identifier placed in ``model_source``.
        max_categories: Value placed in ``max_categories``.
        mode: Optional PDF processing mode; only emitted for PDF input.

    Returns:
        A source-code string. User-provided strings are emitted as escaped
        literals so quotes or backslashes cannot break the generated script.
    """
    import json

    def quote(value):
        # JSON string syntax is a valid Python double-quoted literal.
        return json.dumps(str(value), ensure_ascii=False)

    mode_line = ''
    if input_type == "text":
        header = (
            'import catllm\n'
            'import pandas as pd\n'
            '\n'
            '# Load your data\n'
            'df = pd.read_csv("your_data.csv")\n'
            '\n'
            '# Extract categories from the text column'
        )
        input_data = f'df[{quote(description)}].tolist()'
        declared_type = "text"
    elif input_type == "pdf":
        header = 'import catllm\n\n# Extract categories from PDF documents'
        input_data = '"path/to/your/pdfs/"'
        declared_type = "pdf"
        if mode:
            mode_line = f',\n    mode={quote(mode)}'
    else:  # image
        header = 'import catllm\n\n# Extract categories from images'
        input_data = '"path/to/your/images/"'
        declared_type = "image"

    return f'''{header}
result = catllm.extract(
    input_data={input_data},
    api_key="YOUR_API_KEY",
    input_type="{declared_type}",
    description={quote(description)}{mode_line},
    user_model={quote(model)},
    model_source={quote(model_source)},
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
def generate_full_code(extraction_params, classify_params):
    """Generate combined extract + classify code when categories were auto-extracted.

    Args:
        extraction_params: Dict with ``input_type``, ``description``,
            ``model``, ``max_categories`` and optionally ``mode``.
        classify_params: Dict with ``classify_mode`` and ``description``;
            ``model`` for single-model runs; for multi-model runs either
            ``ensemble_runs`` (``(model, temperature)`` pairs) or
            ``models_list`` plus optional ``model_temperatures``;
            optionally ``mode`` and ``consensus_threshold`` (defaults to
            0.5, i.e. "majority", when absent).

    Returns:
        A source-code string. User-provided strings are emitted as escaped
        literals so quotes cannot break the generated script.
    """
    import json

    def quote(value):
        # JSON string syntax is a valid Python double-quoted literal.
        return json.dumps(str(value), ensure_ascii=False)

    ext = extraction_params
    cls = classify_params
    is_pdf = ext['input_type'] == "pdf"

    # Determine input data placeholder
    if ext['input_type'] == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = (
            'import pandas as pd\n'
            '\n'
            '# Load your data\n'
            'df = pd.read_csv("your_data.csv")\n'
        )
    elif is_pdf:
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    extract_mode = ext.get('mode')
    mode_param = f',\n    mode={quote(extract_mode)}' if extract_mode else ''
    ext_description = quote(ext['description'])
    ext_model = quote(ext['model'])
    max_categories = ext['max_categories']

    # Build extract code
    extract_code = f'''# Step 1: Extract categories from your data
extract_result = catllm.extract(
    input_data={input_placeholder},
    api_key="YOUR_API_KEY",
    description={ext_description},
    user_model={ext_model},
    max_categories={max_categories}{mode_param}
)

categories = extract_result["top_categories"]
print(f"Extracted {{len(categories)}} categories: {{categories}}")
'''

    classify_mode = cls['classify_mode']
    classify_pdf_mode = cls.get('mode')
    classify_mode_param = (
        f',\n    mode={quote(classify_pdf_mode)}' if classify_pdf_mode and is_pdf else ''
    )
    cls_description = quote(cls['description'])

    # Build classify code based on mode
    if classify_mode == "Single Model":
        cls_model = quote(cls['model'])
        classify_code = f'''
# Step 2: Classify data using extracted categories
result = catllm.classify(
    input_data={input_placeholder},
    categories=categories,
    api_key="YOUR_API_KEY",
    description={cls_description},
    user_model={cls_model}{classify_mode_param}
)'''
    else:
        # Multi-model mode - include per-model temperatures when set
        model_lines = []
        ens_runs = cls.get('ensemble_runs')
        if ens_runs:
            # Explicit (model, temperature) pairs; duplicates are allowed.
            for m, temp in ens_runs:
                model_lines.append(
                    f'({quote(m)}, "auto", "YOUR_API_KEY", {{"creativity": {temp}}})'
                )
        else:
            model_temps = cls.get('model_temperatures') or {}
            for m in cls.get('models_list') or []:
                temp = model_temps.get(m)
                if temp is not None:
                    model_lines.append(
                        f'({quote(m)}, "auto", "YOUR_API_KEY", {{"creativity": {temp}}})'
                    )
                else:
                    model_lines.append(f'({quote(m)}, "auto", "YOUR_API_KEY")')
        models_str = ",\n    ".join(model_lines)

        is_ensemble = classify_mode == "Ensemble"
        consensus_param = ''
        if is_ensemble:
            # Only ensembles vote, so the threshold is only read here.
            threshold = cls.get('consensus_threshold', 0.5)
            if threshold == 0.5:
                threshold_str = "majority"
            elif threshold == 0.67:
                threshold_str = "two-thirds"
            else:
                threshold_str = "unanimous"
            consensus_param = f',\n    consensus_threshold="{threshold_str}"'
        strategy = "ensemble voting" if is_ensemble else "model comparison"

        classify_code = f'''
# Step 2: Classify data using extracted categories with {strategy}
models = [
    {models_str}
]

result = catllm.classify(
    input_data={input_placeholder},
    categories=categories,
    models=models,
    description={cls_description}{classify_mode_param}{consensus_param}
)'''

    return f'''import catllm
{load_data}
{extract_code}
{classify_code}

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
def generate_classify_code(input_type, description, categories, model, model_source, mode=None,
                           classify_mode="Single Model", models_list=None, consensus_threshold=0.5,
                           model_temperatures=None, ensemble_runs=None):
    """Generate Python code for classification.

    Args:
        input_type: ``"text"``, ``"pdf"``; anything else is treated as image.
        description: Data description passed to ``catllm.classify``.
        categories: Category names; ``None`` is treated as an empty list.
        model: Model name for single-model mode.
        model_source: Accepted for signature compatibility; not emitted.
        mode: Optional PDF processing mode; only emitted for PDF input.
        classify_mode: ``"Single Model"``, ``"Model Comparison"`` or ``"Ensemble"``.
        models_list: Model names for multi-model modes.
        consensus_threshold: 0.5 -> "majority", 0.67 -> "two-thirds",
            anything else -> "unanimous" (ensemble only).
        model_temperatures: Optional mapping of model name -> temperature.
        ensemble_runs: Optional ``(model, temperature)`` pairs; takes
            precedence over ``models_list`` and supports duplicate models.

    Returns:
        A source-code string. User-provided strings are emitted as escaped
        literals so quotes cannot break the generated script.
    """
    import json

    def quote(value):
        # JSON string syntax is a valid Python double-quoted literal.
        return json.dumps(str(value), ensure_ascii=False)

    categories_str = ",\n    ".join(quote(cat) for cat in (categories or []))

    # Determine input data placeholder based on type
    if input_type == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = (
            'import pandas as pd\n'
            '\n'
            '# Load your data\n'
            'df = pd.read_csv("your_data.csv")\n'
        )
    elif input_type == "pdf":
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:  # image
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    mode_param = f',\n    mode={quote(mode)}' if mode and input_type == "pdf" else ''
    description_literal = quote(description)

    if classify_mode == "Single Model":
        model_literal = quote(model)
        return f'''import catllm
{load_data}
# Define categories
categories = [
    {categories_str}
]

# Classify data (input type is auto-detected)
result = catllm.classify(
    input_data={input_placeholder},
    categories=categories,
    api_key="YOUR_API_KEY",
    description={description_literal},
    user_model={model_literal}{mode_param}
)

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''

    # Multi-model mode (Comparison or Ensemble)
    # Build model tuples with per-model temperature when set
    model_lines = []
    if ensemble_runs:
        # Ensemble with explicit (model, temp) pairs (supports duplicate models)
        for m, temp in ensemble_runs:
            model_lines.append(
                f'({quote(m)}, "auto", "YOUR_API_KEY", {{"creativity": {temp}}})'
            )
    elif models_list:
        temperatures = model_temperatures or {}
        for m in models_list:
            temp = temperatures.get(m)
            if temp is not None:
                model_lines.append(
                    f'({quote(m)}, "auto", "YOUR_API_KEY", {{"creativity": {temp}}})'
                )
            else:
                model_lines.append(f'({quote(m)}, "auto", "YOUR_API_KEY")')
    else:
        # No models supplied: emit an illustrative pair.
        model_lines = [
            '("gpt-4o", "auto", "YOUR_API_KEY")',
            '("claude-sonnet-4-5-20250929", "auto", "YOUR_API_KEY")',
        ]
    models_str = ",\n    ".join(model_lines)

    is_ensemble = classify_mode == "Ensemble"
    # Map numeric threshold back to string for cleaner code
    if consensus_threshold == 0.5:
        threshold_str = "majority"
    elif consensus_threshold == 0.67:
        threshold_str = "two-thirds"
    else:
        threshold_str = "unanimous"
    consensus_param = f',\n    consensus_threshold="{threshold_str}"' if is_ensemble else ''
    purpose = "ensemble voting" if is_ensemble else "comparison"

    return f'''import catllm
{load_data}
# Define categories
categories = [
    {categories_str}
]

# Define models for {purpose}
models = [
    {models_str}
]

# Classify with multiple models
result = catllm.classify(
    input_data={input_placeholder},
    categories=categories,
    models=models,
    description={description_literal}{mode_param}{consensus_param}
)

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
def _report_table(rows, col_widths, header=False, label_rows_from=None):
    """Build a gridded 9pt ReportLab table with optional header/label shading."""
    from reportlab.lib import colors
    from reportlab.platypus import Table, TableStyle

    commands = [
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]
    if header:
        commands.append(('BACKGROUND', (0, 0), (-1, 0), colors.grey))
        commands.append(('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke))
    if label_rows_from is not None:
        # Shade the first column from the given row down.
        commands.append(('BACKGROUND', (0, label_rows_from), (0, -1), colors.lightgrey))

    table = Table(rows, colWidths=col_widths)
    table.setStyle(TableStyle(commands))
    return table


def _save_figure_png(fig):
    """Write a matplotlib figure to a new temporary PNG and return its path."""
    fd, path = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    fig.savefig(path, format='png', dpi=150, bbox_inches='tight')
    return path


def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
                                    result_df=None, processing_time=None, prompt_template=None,
                                    data_quality=None, catllm_version=None, python_version=None,
                                    task_type="assign", extracted_categories_df=None, max_categories=None,
                                    input_type="text", description=None,
                                    classify_mode="Single Model", models_list=None, code=None,
                                    consensus_threshold=None):
    """Generate a PDF methodology report.

    The report contains an introduction, the category-to-column mapping, a
    citation, a run summary (plus ensemble agreement scores, timing and
    version tables when available), the reproducibility ``code`` when given,
    and distribution / classification-matrix charts when ``result_df`` and
    ``categories`` are provided.

    ``prompt_template``, ``data_quality``, ``extracted_categories_df``,
    ``max_categories``, ``input_type`` and ``description`` are accepted for
    signature compatibility but not rendered.

    Returns:
        Path of the generated temporary PDF file (the caller owns it).
    """
    from xml.sax.saxutils import escape
    from reportlab.lib.pagesizes import letter
    from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
    from reportlab.platypus import Image, PageBreak, Paragraph, SimpleDocTemplate, Spacer

    # Reserve the output path without keeping an open handle around.
    fd, pdf_path = tempfile.mkstemp(suffix='_methodology_report.pdf')
    os.close(fd)
    doc = SimpleDocTemplate(pdf_path, pagesize=letter)

    styles = getSampleStyleSheet()
    title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, spaceAfter=20)
    heading_style = ParagraphStyle('Heading', parent=styles['Heading2'], fontSize=14, spaceAfter=10, spaceBefore=15)
    normal_style = styles['Normal']
    code_style = ParagraphStyle('Code', parent=styles['Normal'], fontName='Courier', fontSize=9,
                                leftIndent=20, spaceAfter=3)

    story = []

    # --- Title and introduction -------------------------------------------
    if task_type == "extract_and_assign":
        report_title = "CatLLM Extraction & Classification Report"
        about_text = (
            "This methodology report documents the automated category extraction and classification process. "
            "CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories."
        )
    else:
        report_title = "CatLLM Classification Report"
        about_text = (
            "This methodology report documents the classification process for reproducibility and transparency. "
            "CatLLM restricts the prompt to a standard template that is impartial to the researcher's inclinations, ensuring "
            "consistent and reproducible results."
        )

    # Paragraph text is parsed as XML markup, so literal text is escaped.
    story.append(Paragraph(escape(report_title), title_style))
    story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", normal_style))
    story.append(Spacer(1, 15))
    story.append(Paragraph("About This Report", heading_style))
    story.append(Paragraph(about_text, normal_style))
    story.append(Spacer(1, 15))

    # --- Category mapping ---------------------------------------------------
    if categories:
        story.append(Paragraph("Category Mapping", heading_style))
        category_data = [["Column Name", "Category Description"]]

        if classify_mode in ("Ensemble", "Model Comparison") and result_df is not None:
            # Multi-model: show per-model columns and consensus columns
            story.append(Paragraph(
                "Each model produces its own binary columns. "
                "Consensus columns show the majority vote result.", normal_style))
            story.append(Spacer(1, 8))

            # Detect ALL distinct model suffixes directly from the DataFrame
            # (handles same-model-different-temperature cases correctly)
            all_suffixes = _find_all_model_suffixes(result_df)
            for i, cat in enumerate(categories, 1):
                for suffix in all_suffixes:
                    category_data.append([f"category_{i}_{suffix}", f"{cat} ({suffix})"])
                category_data.append([f"category_{i}_consensus", f"{cat} (consensus)"])
                category_data.append([f"category_{i}_agreement", f"{cat} (agreement score)"])
            col_widths = [200, 250]
        else:
            # Single model: simple mapping
            story.append(Paragraph(
                "Each category column contains binary values: 1 = present, 0 = not present", normal_style))
            story.append(Spacer(1, 8))
            for i, cat in enumerate(categories, 1):
                category_data.append([f"category_{i}", cat])
            col_widths = [120, 330]

        story.append(_report_table(category_data, col_widths, header=True, label_rows_from=1))
        story.append(Spacer(1, 15))

    # --- Citation -----------------------------------------------------------
    story.append(Spacer(1, 30))
    story.append(Paragraph("Citation", heading_style))
    story.append(Paragraph("If you use CatLLM in your research, please cite:", normal_style))
    story.append(Spacer(1, 5))
    story.append(Paragraph(
        "Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. "
        "DOI: 10.5281/zenodo.15532316", normal_style))

    # --- Summary section ----------------------------------------------------
    story.append(PageBreak())
    story.append(Paragraph("Classification Summary", title_style))
    story.append(Spacer(1, 15))

    summary_data = [
        ["Source File", filename],
        ["Source Column", column_name],
        ["Classification Mode", classify_mode],
        ["Model(s) Used", model],
        ["Model Source", model_source],
        ["Rows Classified", str(num_rows)],
        ["Number of Categories", str(len(categories)) if categories else "0"],
        ["Success Rate", f"{success_rate:.2f}%"],
    ]
    # Add consensus threshold for ensemble mode
    if classify_mode == "Ensemble" and consensus_threshold is not None:
        threshold_labels = {0.5: "Majority (50%+)", 0.67: "Two-Thirds (67%+)", 1.0: "Unanimous (100%)"}
        threshold_label = threshold_labels.get(consensus_threshold, f"Custom ({consensus_threshold:.0%})")
        summary_data.append(["Consensus Threshold", threshold_label])

    story.append(_report_table(summary_data, [150, 300], label_rows_from=0))
    story.append(Spacer(1, 15))

    # Agreement scores table for ensemble mode
    if classify_mode == "Ensemble" and result_df is not None and categories:
        agreement_cols = [f"category_{i}_agreement" for i in range(1, len(categories) + 1)]
        if all(col in result_df.columns for col in agreement_cols):
            story.append(Paragraph("Ensemble Agreement Scores", heading_style))
            story.append(Paragraph(
                "Agreement shows what proportion of models agreed on each category. "
                "Higher scores indicate stronger consensus.", normal_style))
            story.append(Spacer(1, 8))

            agree_data = [["Category", "Mean Agreement", "Min Agreement"]]
            for cat, col in zip(categories, agreement_cols):
                mean_val = result_df[col].mean()
                min_val = result_df[col].min()
                agree_data.append([cat, f"{mean_val:.1%}", f"{min_val:.1%}"])

            story.append(_report_table(agree_data, [200, 125, 125], header=True))
            story.append(Spacer(1, 15))

    if processing_time is not None:
        story.append(Paragraph("Processing Time", heading_style))
        rows_per_min = (num_rows / processing_time) * 60 if processing_time > 0 else 0
        avg_time = processing_time / num_rows if num_rows > 0 else 0
        time_data = [
            ["Total Processing Time", f"{processing_time:.1f} seconds"],
            ["Average Time per Response", f"{avg_time:.2f} seconds"],
            ["Processing Rate", f"{rows_per_min:.1f} rows/minute"],
        ]
        story.append(_report_table(time_data, [180, 270], label_rows_from=0))
        story.append(Spacer(1, 15))

    story.append(Paragraph("Version Information", heading_style))
    version_data = [
        ["CatLLM Version", catllm_version or "unknown"],
        ["Python Version", python_version or "unknown"],
        ["Timestamp", datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
    ]
    story.append(_report_table(version_data, [180, 270], label_rows_from=0))

    # --- Reproducibility Code section --------------------------------------
    if code:
        story.append(PageBreak())
        story.append(Paragraph("Reproducibility Code", title_style))
        story.append(Paragraph(
            "Use this Python code to reproduce the classification with the CatLLM package:", normal_style))
        story.append(Spacer(1, 10))

        for line in code.strip().split('\n'):
            if not line.strip():
                story.append(Spacer(1, 6))
                continue
            stripped = line.lstrip(' ')
            # Paragraph collapses leading spaces; non-breaking spaces keep
            # the Python indentation visible in the rendered code.
            indent = '&nbsp;' * (len(line) - len(stripped))
            story.append(Paragraph(indent + escape(stripped), code_style))

    # --- Visualizations section ---------------------------------------------
    chart_paths = []
    if result_df is not None and categories:
        charts = (
            ("Category Distribution", create_distribution_chart, 250,
             "Note: Categories are not mutually exclusive—each item can belong to multiple categories.",
             "distribution chart"),
            ("Classification Matrix", create_classification_heatmap, 300,
             "Orange = category present, Black = not present. Each row represents one response.",
             "classification matrix"),
        )
        for chart_title, make_figure, height, caption, label in charts:
            story.append(PageBreak())
            story.append(Paragraph(chart_title, title_style))
            try:
                fig = make_figure(result_df, categories, classify_mode, models_list)
                try:
                    chart_path = _save_figure_png(fig)
                finally:
                    plt.close(fig)
                chart_paths.append(chart_path)
                story.append(Image(chart_path, width=450, height=height))
                story.append(Spacer(1, 10))
                story.append(Paragraph(caption, normal_style))
            except Exception as e:
                # A failed chart must not abort the whole report.
                story.append(Paragraph(f"Could not generate {label}: {escape(str(e))}", normal_style))

    try:
        doc.build(story)
    finally:
        # The PNGs are only needed while the document is being built.
        for chart_path in chart_paths:
            try:
                os.remove(chart_path)
            except OSError:
                pass

    return pdf_path
def _plan_extraction_chunks(input_type, num_items):
    """Return ``(divisions, categories_per_chunk)`` for an extraction run."""
    if input_type == "image":
        return min(3, max(1, num_items // 5)), 12

    divisions = min(5, max(1, num_items // 15))
    chunk_size = num_items // divisions
    # Clamp to at least 1: with zero or one item ``chunk_size - 1`` would
    # otherwise request zero or a negative number of categories per chunk.
    return divisions, max(1, min(10, chunk_size - 1))


def run_auto_extract(input_type, input_data, description, max_categories_val,
                     model_tier, model, api_key_input, mode=None, progress_callback=None):
    """Extract categories from data.

    Args:
        input_type: Input kind forwarded to ``catllm.extract``.
        input_data: A list of items, or a single item (counted as one).
        description: Data description forwarded to ``catllm.extract``.
        max_categories_val: Maximum number of categories (cast to ``int``).
        model_tier, model, api_key_input: Used to resolve the API key.
        mode: Optional processing mode, forwarded only when truthy.
        progress_callback: Accepted for signature compatibility; unused.

    Returns:
        ``(categories, message)``; ``categories`` is ``None`` on failure and
        ``message`` then explains why.
    """
    if not CATLLM_AVAILABLE:
        return None, "catllm package not available"

    actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
    if not actual_api_key:
        return None, f"{provider} API key not configured"

    model_source = get_model_source(model)

    try:
        num_items = len(input_data) if isinstance(input_data, list) else 1
        divisions, categories_per_chunk = _plan_extraction_chunks(input_type, num_items)

        extract_kwargs = {
            'input_data': input_data,
            'api_key': actual_api_key,
            'input_type': input_type,
            'description': description,
            'user_model': model,
            'model_source': model_source,
            'divisions': divisions,
            'categories_per_chunk': categories_per_chunk,
            'max_categories': int(max_categories_val),
        }
        if mode:
            extract_kwargs['mode'] = mode

        extract_result = catllm.extract(**extract_kwargs)
        categories = extract_result.get('top_categories', [])

        if not categories:
            return None, "No categories were extracted"
        return categories, f"Extracted {len(categories)} categories successfully!"
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        return None, f"Error: {str(e)}"
def run_classify_data(input_type, input_data, description, categories,
                      model_tier, model, api_key_input, mode=None,
                      original_filename="data", column_name="text",
                      progress_callback=None):
    """Classify data with user-provided categories.

    Args:
        input_type: Input kind, used for the report and generated code.
        input_data: Data forwarded to ``catllm.classify``.
        description: Data description forwarded to ``catllm.classify``.
        categories: Non-empty list of category names.
        model_tier, model, api_key_input: Used to resolve the API key.
        mode: Optional processing mode, forwarded only when truthy.
        original_filename, column_name: Labels shown in the report.
        progress_callback: Accepted for signature compatibility; unused.

    Returns:
        ``(result_df, csv_path, report_pdf_path, code, message)``; on failure
        the first four entries are ``None`` and ``message`` explains why.
    """
    if not CATLLM_AVAILABLE:
        return None, None, None, None, "catllm package not available"
    if not categories:
        return None, None, None, None, "Please enter at least one category"

    actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
    if not actual_api_key:
        return None, None, None, None, f"{provider} API key not configured"

    model_source = get_model_source(model)

    try:
        start_time = time.time()
        classify_kwargs = {
            'input_data': input_data,
            'categories': categories,
            'models': [(model, model_source, actual_api_key)],
            'description': description,
        }
        if mode:
            classify_kwargs['mode'] = mode

        result = catllm.classify(**classify_kwargs)
        processing_time = time.time() - start_time
        num_items = len(result)

        # Save CSV: reserve the path and close the handle first so pandas is
        # the only writer of the file.
        fd, csv_path = tempfile.mkstemp(suffix='_classified.csv')
        os.close(fd)
        result.to_csv(csv_path, index=False)

        # Calculate success rate (an empty result has nothing that failed)
        if 'processing_status' in result.columns and num_items > 0:
            success_count = (result['processing_status'] == 'success').sum()
            success_rate = (success_count / num_items) * 100
        else:
            success_rate = 100.0

        # Get version info
        catllm_version = getattr(catllm, '__version__', "unknown")
        python_version = sys.version.split()[0]

        # Generate the reproducibility code first so the report can embed it
        code = generate_classify_code(input_type, description, categories, model, model_source, mode)

        # Generate methodology report
        report_pdf_path = generate_methodology_report_pdf(
            categories=categories,
            model=model,
            column_name=column_name,
            num_rows=num_items,
            model_source=model_source,
            filename=original_filename,
            success_rate=success_rate,
            result_df=result,
            processing_time=processing_time,
            catllm_version=catllm_version,
            python_version=python_version,
            task_type="assign",
            input_type=input_type,
            description=description,
            code=code,
        )

        return (result, csv_path, report_pdf_path, code,
                f"Classified {num_items} items in {processing_time:.1f}s")
    except Exception as e:
        # UI boundary: report the failure instead of crashing the app.
        return None, None, None, None, f"Error: {str(e)}"
def sanitize_model_name(model: str) -> str:
    """Convert model name to column-safe suffix (matches catllm logic).

    Non-alphanumeric characters become underscores, runs of underscores are
    collapsed, leading/trailing underscores are stripped, the result is
    lower-cased and truncated to 40 characters.
    """
    import re

    sanitized = re.sub(r'[^a-zA-Z0-9]', '_', model)
    sanitized = re.sub(r'_+', '_', sanitized)
    sanitized = sanitized.strip('_').lower()
    return sanitized[:40]


def _find_model_column_suffix(result_df, model_name):
    """Find the actual column suffix used for a model in the DataFrame.

    catllm appends a creativity suffix (e.g. _tauto, _t50) to ensemble column
    names, so we can't just use sanitize_model_name(). This function looks at
    the real DataFrame columns to discover the full suffix.

    An exact match, or a match followed only by a creativity suffix, is
    preferred so that a model whose name is a prefix of another model's name
    (``gpt-4o`` vs ``gpt-4o-mini``) does not pick up the other's columns.
    Otherwise the first column starting with the sanitized name is used, and
    the sanitized name itself is returned when nothing matches.
    """
    import re

    sanitized = sanitize_model_name(model_name)
    prefix = f"category_1_{sanitized}"
    strip_len = len("category_1_")

    first_prefix_match = None
    for col in result_df.columns:
        if not isinstance(col, str) or not col.startswith(prefix):
            continue
        remainder = col[len(prefix):]
        if remainder == "" or re.fullmatch(r"_t(?:auto|\d+)", remainder):
            return col[strip_len:]
        if first_prefix_match is None:
            first_prefix_match = col[strip_len:]

    if first_prefix_match is not None:
        return first_prefix_match
    # Fallback: return just the sanitized name
    return sanitized


def _find_all_model_suffixes(result_df):
    """Discover all distinct per-model column suffixes from the DataFrame.

    Looks at category_1_* columns (excluding _consensus and _agreement) to
    find every unique model suffix. Works even when the same model appears
    multiple times with different temperature suffixes.

    Returns:
        List of suffix strings in column order without duplicates, e.g.
        ['claude_haiku_4_5_20251001_t0', 'claude_haiku_4_5_20251001_t25', ...]
    """
    import re

    suffixes = []
    for col in result_df.columns:
        if not isinstance(col, str):
            continue  # non-string labels cannot be category columns
        m = re.match(r'^category_1_(.+)$', col)
        if not m:
            continue
        suffix = m.group(1)
        if suffix in ('consensus', 'agreement') or suffix in suffixes:
            continue
        suffixes.append(suffix)
    return suffixes


def create_classification_heatmap(result_df, categories, classify_mode="Single Model", models_list=None):
    """Create a binary heatmap showing classification for each row.

    Args:
        result_df: DataFrame with classification results
        categories: List of category names
        classify_mode: "Single Model", "Model Comparison", or "Ensemble"
        models_list: List of model names (for multi-model modes)

    Returns:
        A matplotlib figure. Missing category columns are drawn as all
        "not present"; missing or non-numeric cells count as 0.
    """
    import numpy as np
    from matplotlib.colors import ListedColormap
    from matplotlib.patches import Patch

    total_rows = len(result_df)
    if total_rows == 0:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig

    category_numbers = range(1, len(categories) + 1)
    compare_first_model = classify_mode == "Model Comparison" and bool(models_list)

    # Pick the column family that matches classify_mode
    if classify_mode == "Ensemble":
        col_names = [f"category_{i}_consensus" for i in category_numbers]
    elif compare_first_model:
        # Use first model's columns (detect actual suffix from DataFrame)
        suffix = _find_model_column_suffix(result_df, models_list[0])
        col_names = [f"category_{i}_{suffix}" for i in category_numbers]
    else:
        col_names = [f"category_{i}" for i in category_numbers]

    # Extract the binary matrix
    matrix_data = []
    for col in col_names:
        if col in result_df.columns:
            # Failed rows may hold NaN; coerce instead of crashing in astype.
            values = pd.to_numeric(result_df[col], errors='coerce').fillna(0).astype(int).values
        else:
            values = np.zeros(total_rows, dtype=int)
        matrix_data.append(values)
    matrix = np.array(matrix_data).T  # Rows = responses, Cols = categories

    # Create figure with appropriate sizing
    fig_height = max(4, min(20, total_rows * 0.15))
    fig_width = max(8, len(categories) * 0.8)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Custom colormap: black (0) and orange (1) - CatLLM theme
    cmap = ListedColormap(['#1a1a1a', '#E8A33C'])
    ax.imshow(matrix, aspect='auto', cmap=cmap, vmin=0, vmax=1)

    # Set labels - remove y-axis numbers for cleaner look
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories, rotation=45, ha='right', fontsize=9)
    ax.set_xlabel('Categories', fontsize=11)
    ax.set_ylabel(f'Responses (n={total_rows})', fontsize=11)
    ax.set_yticks([])  # Remove y-axis tick marks

    title = 'Classification Matrix'
    if classify_mode == "Ensemble":
        title += ' (Ensemble Consensus)'
    elif compare_first_model:
        # Guarded: without a model list there is no model name to show.
        title += f' ({models_list[0].split("/")[-1].split(":")[0][:20]})'
    ax.set_title(title, fontsize=14, fontweight='bold')

    # Add legend
    legend_elements = [
        Patch(facecolor='#1a1a1a', edgecolor='white', label='Not Present'),
        Patch(facecolor='#E8A33C', edgecolor='white', label='Present'),
    ]
    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1))

    plt.tight_layout()
    return fig
def _category_percentage(result_df, col_name, total_rows):
    """Return the share (0-100) of rows flagged in ``col_name``, or None if absent."""
    if col_name not in result_df.columns:
        return None
    count = int(result_df[col_name].sum())
    return (count / total_rows) * 100


def _placeholder_figure(message):
    """Return a small axis-less figure that only shows ``message``."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=14)
    ax.axis('off')
    return fig


def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None):
    """Create a bar chart showing category distribution.

    Args:
        result_df: DataFrame with classification results
        categories: List of category names
        classify_mode: "Single Model", "Model Comparison", or "Ensemble"
        models_list: List of model names (for multi-model modes)

    Returns:
        A matplotlib figure: one horizontal bar per category for single-model
        and ensemble runs (categories without a column are omitted), or
        grouped bars per model for comparisons (missing columns plot as 0).
        A placeholder figure is returned when there are no rows, or when a
        comparison is requested without any models.
    """
    import numpy as np

    total_rows = len(result_df)
    if total_rows == 0:
        return _placeholder_figure('No data to display')

    # Define colors for different models
    model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04', '#9333ea', '#0891b2', '#be185d', '#65a30d']

    if classify_mode in ("Single Model", "Ensemble"):
        # Single model reads category_N; ensemble reads category_N_consensus.
        if classify_mode == "Single Model":
            column_suffix, bar_color, chart_title = "", '#2563eb', 'Category Distribution (%)'
        else:
            column_suffix, bar_color, chart_title = "_consensus", '#16a34a', 'Ensemble Consensus Distribution (%)'

        fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))

        labels = []
        percentages = []
        for i, cat in enumerate(categories, 1):
            pct = _category_percentage(result_df, f"category_{i}{column_suffix}", total_rows)
            if pct is not None:
                labels.append(cat)
                percentages.append(round(pct, 1))

        # Reverse so the first category ends up at the top of the chart.
        labels.reverse()
        percentages.reverse()

        bars = ax.barh(labels, percentages, color=bar_color)
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title(chart_title, fontsize=14, fontweight='bold')

        for bar, pct in zip(bars, percentages):
            ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                    f'{pct:.1f}%', va='center', fontsize=10)
    else:
        # Model Comparison: grouped bars for each model
        if not models_list:
            # Nothing to group; also avoids dividing the bar height by zero.
            return _placeholder_figure('No model results to display')

        # Detect actual column suffixes from the DataFrame
        model_suffixes = [_find_model_column_suffix(result_df, m) for m in models_list]
        n_models = len(model_suffixes)
        n_categories = len(categories)

        fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2)))

        bar_height = 0.8 / n_models
        y_positions = np.arange(n_categories)

        for model_idx, (model_name, suffix) in enumerate(zip(models_list, model_suffixes)):
            model_pcts = []
            for i in range(1, n_categories + 1):
                pct = _category_percentage(result_df, f"category_{i}_{suffix}", total_rows)
                model_pcts.append(0 if pct is None else pct)

            # Reverse for horizontal bar chart
            model_pcts.reverse()

            offset = (model_idx - n_models / 2 + 0.5) * bar_height
            color = model_colors[model_idx % len(model_colors)]

            # Use shorter display name
            display_name = model_name.split('/')[-1].split(':')[0][:20]

            ax.barh(y_positions + offset, model_pcts, bar_height * 0.9,
                    label=display_name, color=color, alpha=0.85)

        ax.set_yticks(y_positions)
        ax.set_yticklabels(categories[::-1])
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)

    plt.tight_layout()
    return fig


# Page config
st.set_page_config(
    page_title="CatLLM - Research Data Classifier",
    page_icon="🐱",
    layout="wide"
)

# Custom CSS for enhanced styling
# NOTE(review): the stylesheet payload appears empty in this view - confirm
# against the original file before relying on this literal.
st.markdown(""" """, unsafe_allow_html=True)

# Initialize session state: only keys that are not set yet receive a default,
# so values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    'categories': [''] * MAX_CATEGORIES,
    'category_count': INITIAL_CATEGORIES,
    'task_mode': None,
    'extracted_categories': None,
    'results': None,
    'active_tab': "survey",
    'survey_data': None,
    'pdf_data': None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
if 'image_data' not in st.session_state: st.session_state.image_data = None if 'extraction_params' not in st.session_state: st.session_state.extraction_params = None # Stores params when categories are auto-extracted # Logo and title - use HTML for better alignment st.markdown("""