""" Multi-Stage Reasoning Engine for Agricultural Chatbot ====================================================== Implements the full spec: Claim → Validate → Contradict → Confirm pipeline with priority-based context selection and evidence tracking. """ import json import logging from typing import Dict, List, Any, Optional, Tuple, Callable from dataclasses import dataclass, field from intent_classifier import IntentClassifier from priority_mapper import PriorityContextMapper from prompts import ( SYSTEM_PROMPT, CLAIM_PROMPT, VALIDATE_PROMPT, CONTRADICT_PROMPT, CONFIRM_PROMPT, RESPONSE_PROMPT, format_stage_prompt, build_context_prompt, generate_followup_questions, # Compact prompts for token reduction COMPACT_CLAIM_PROMPT, COMPACT_VALIDATE_PROMPT, COMPACT_CONTRADICT_PROMPT, COMPACT_CONFIRM_PROMPT, COMPACT_RESPONSE_PROMPT, build_compact_context, get_compact_prompt, format_minimal_diagnosis, # Hybrid Prompts FAST_LANE_PROMPT, DEEP_DIVE_HYPOTHESIS_PROMPT, DEEP_DIVE_ADVERSARY_PROMPT, DEEP_DIVE_JUDGE_PROMPT ) from context_aggregator import ContextAggregator logger = logging.getLogger("ReasoningEngine") # Toggle compact prompts to reduce token usage (saves ~50% tokens) USE_COMPACT_PROMPTS = True # ============================================================================= # DATA CLASSES # ============================================================================= @dataclass class StageResult: """Result from a single reasoning stage.""" stage: str output: Dict[str, Any] context_used: List[str] confidence: float raw_response: str = "" @dataclass class ReasoningResult: """Complete reasoning chain result.""" claim: StageResult validation: StageResult contradiction: StageResult confirmation: StageResult final_diagnosis: str final_confidence: float causal_chain: str root_cause: str symptoms: List[str] recommendation: str evidence_summary: Dict[str, List[str]] # ============================================================================= # REASONING ENGINE # ============================================================================= class ReasoningEngine: """ Multi-stage reasoning engine for agricultural chatbot. Following the spec: - Does NOT ingest all context at once - Uses priority-based context selection per intent - Actively seeks contradictions in Stage C - Tracks evidence chain for transparency """ def __init__(self, llm_caller: Callable[[str], str]): """ Args: llm_caller: Function that takes (prompt: str) -> str """ self.llm = llm_caller self.intent_classifier = IntentClassifier() self.priority_mapper = PriorityContextMapper() self.aggregator = ContextAggregator() def process_query( self, query: str, context: Optional[Dict[str, Any]] = None ) -> Tuple[str, Dict[str, Any]]: """ Process user query through Hybrid Architecture (Fast Lane vs Deep Dive). """ logger.info(f"Processing query: {query[:50]}...") # Stage 1: Classify intent intent = self.intent_classifier.classify(query) logger.info(f"Intent: {intent['primary_intent']} ({intent['confidence']})") # Stage 2: Route Query mode = self.route_query(query, intent) logger.info(f"Routing mode: {mode}") # Stage 3: Execute Logic if mode == "FAST_LANE": reasoning_result = self._execute_fast_lane(query, intent, context or {}) else: reasoning_result = self._execute_deep_dive(query, intent, context or {}) # Stage 4: Generate response (pass full context for persona/weather/zone) # Note: Fast Lane already generates action/diagnosis, but we standardize output format response = self._generate_response(query, reasoning_result, context) # Stage 5: Generate followups followups = generate_followup_questions( intent["primary_intent"], reasoning_result.final_diagnosis ) # Build complete trace trace = self._build_trace(intent, reasoning_result, {}, followups) trace["routing_mode"] = mode return response, trace def route_query(self, query: str, intent: Dict) -> str: """Decide between Fast Lane and Deep Dive.""" # Intention-based routing fast_intents = ["vegetation_health", "water_stress", "nutrient_status"] if intent["primary_intent"] in fast_intents and intent["confidence"] > 0.7: return "FAST_LANE" # "Why" questions or Comparisons usually need Deep Dive if "compare" in query.lower() or "difference" in query.lower(): return "DEEP_DIVE" return "DEEP_DIVE" # Default to robust mode for safety def _execute_fast_lane(self, query: str, intent: Dict, context: Dict) -> ReasoningResult: """Execute 1-Shot Reasoning.""" logger.info("Executing FAST LANE (1-Call)...") # Build ultra-compact context compact_ctx = self.aggregator.build_ultra_compact_context(context) # Include user query in prompt prompt = FAST_LANE_PROMPT.format(query=query, context=compact_ctx) full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" response = self.llm(full_prompt) output = self._parse_json_safe(response, { "reasoning_trace": "Analysis failed", "diagnosis": "Unknown", "confidence": 0.0, "action": "Consult expert" }) # Create Dummy StageResult for compatibility dummy_stage = StageResult("fast_lane", output, [], output.get("confidence", 0.0)) return ReasoningResult( claim=dummy_stage, # Fill for struct compatibility validation=dummy_stage, contradiction=dummy_stage, confirmation=dummy_stage, final_diagnosis=output.get("diagnosis", "Unknown"), final_confidence=output.get("confidence", 0.0), causal_chain=output.get("reasoning_trace", ""), root_cause=output.get("diagnosis", "Unknown"), symptoms=[], recommendation=output.get("action", ""), evidence_summary={"method": ["fast_lane_optimization"]} ) def _execute_deep_dive(self, query: str, intent: Dict, context: Dict) -> ReasoningResult: """Execute 3-Stage Deep Dive.""" logger.info("Executing DEEP DIVE (3-Call)...") # 1. Hypothesis Generation - Include query ctx_hyp = self.aggregator.build_deep_dive_context(context, "hypothesis") resp_hyp = self.llm(f"{SYSTEM_PROMPT}\n\nUSER QUERY: {query}\n\n{DEEP_DIVE_HYPOTHESIS_PROMPT.format(query=query, context=ctx_hyp)}") out_hyp = self._parse_json_safe(resp_hyp, {"hypotheses": []}) # 2. Adversarial Check - Include query context ctx_adv = self.aggregator.build_deep_dive_context(context, "adversary") hyp_str = json.dumps(out_hyp, indent=2) resp_adv = self.llm(f"{SYSTEM_PROMPT}\n\nUSER QUERY: {query}\n\n{DEEP_DIVE_ADVERSARY_PROMPT.format(query=query, hypotheses=hyp_str, context=ctx_adv)}") out_adv = self._parse_json_safe(resp_adv, {"surviving_hypothesis": "Unknown"}) # 3. Final Verdict - Include query ctx_judge = self.aggregator.build_deep_dive_context(context, "judge") winner = out_adv.get("surviving_hypothesis", "Unknown") resp_judge = self.llm(f"{SYSTEM_PROMPT}\n\nUSER QUERY: {query}\n\n{DEEP_DIVE_JUDGE_PROMPT.format(query=query, hypothesis=winner, context=ctx_judge)}") out_judge = self._parse_json_safe(resp_judge, {"final_diagnosis": winner, "action_plan": {}}) # Map to ReasoningResult # We map stages roughly to Maintain compatibility result_hyp = StageResult("hypothesis", out_hyp, [], 0.0) result_adv = StageResult("adversary", out_adv, [], 0.0) result_judge = StageResult("judge", out_judge, [], 0.0) return ReasoningResult( claim=result_hyp, validation=result_adv, contradiction=result_adv, confirmation=result_judge, final_diagnosis=out_judge.get("final_diagnosis", "Unknown"), final_confidence=0.9, # Deep dive implies high confidence causal_chain=out_judge.get("detailed_reasoning", ""), root_cause=out_judge.get("root_cause", ""), symptoms=[], recommendation=str(out_judge.get("action_plan", "")), evidence_summary={"method": ["deep_dive_3_stage"]} ) def _reason( self, query: str, intent: Dict, staged_context: Dict ) -> ReasoningResult: """Execute 4-stage reasoning pipeline as per spec.""" # Stage A: Initial Claim (Priority 1 context only) logger.info("Stage A: Making initial claim...") claim = self._stage_claim(query, staged_context["claim_context"]) # Stage B: Validate (Add Priority 2 context) logger.info(f"Stage B: Validating hypothesis '{claim.output.get('hypothesis')}'...") validation = self._stage_validate( hypothesis=claim.output.get("hypothesis", "unknown"), confidence=claim.confidence, context=staged_context["validate_context"] ) # Stage C: Contradict (Priority 3 - actively seek alternatives) current_hypothesis = validation.output.get("hypothesis", claim.output.get("hypothesis", "unknown")) logger.info(f"Stage C: Seeking contradictions to '{current_hypothesis}'...") contradiction = self._stage_contradict( hypothesis=current_hypothesis, confidence=validation.confidence, context=staged_context["contradict_context"] ) # Stage D: Confirm (Priority 4 - final decision) logger.info("Stage D: Final confirmation...") confirmation = self._stage_confirm( hypothesis_1=current_hypothesis, conf_1=validation.confidence, hypothesis_2=contradiction.output.get("alternative_hypothesis", "none"), conf_2=contradiction.output.get("alternative_confidence", 0), context=staged_context["confirm_context"] ) return ReasoningResult( claim=claim, validation=validation, contradiction=contradiction, confirmation=confirmation, final_diagnosis=confirmation.output.get("final_diagnosis", "Undetermined"), final_confidence=confirmation.confidence, causal_chain=confirmation.output.get("causal_chain", ""), root_cause=confirmation.output.get("root_cause", "unknown"), symptoms=confirmation.output.get("symptoms", []), recommendation=confirmation.output.get("recommendation", ""), evidence_summary={ "primary": claim.context_used, "supporting": validation.context_used, "alternative": contradiction.context_used, "validation": confirmation.context_used, "supporting_evidence": confirmation.output.get("evidence_summary", {}).get("supporting", []), "contradicting_evidence": confirmation.output.get("evidence_summary", {}).get("contradicting", []), "inconclusive_evidence": confirmation.output.get("evidence_summary", {}).get("inconclusive", []) } ) def _stage_claim(self, query: str, context: Dict) -> StageResult: """Stage 3A: Make initial claim using Priority 1 context only.""" if USE_COMPACT_PROMPTS: compact_ctx = build_compact_context(context) if isinstance(context, dict) else str(context) prompt = COMPACT_CLAIM_PROMPT.format(query=query, context=compact_ctx) else: prompt = format_stage_prompt( CLAIM_PROMPT, query=query, priority_1_context=context ) full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" if not USE_COMPACT_PROMPTS else prompt response = self.llm(full_prompt) output = self._parse_json_safe(response, { "initial_claim": response[:200] if response else "No analysis available", "hypothesis": "general_issue", "evidence_cited": list(context.keys()), "confidence": 0.5, "uncertainties": ["Limited data available"] }) return StageResult( stage="claim", output=output, context_used=list(context.keys()), confidence=output.get("confidence", 0.5), raw_response=response ) def _stage_validate( self, hypothesis: str, confidence: float, context: Dict ) -> StageResult: """Stage 3B: Validate hypothesis using Priority 2 context.""" if USE_COMPACT_PROMPTS: compact_ctx = build_compact_context(context) if isinstance(context, dict) else str(context) prompt = COMPACT_VALIDATE_PROMPT.format( previous_hypothesis=hypothesis, previous_confidence=confidence, priority_2_context=compact_ctx ) else: prompt = format_stage_prompt( VALIDATE_PROMPT, previous_hypothesis=hypothesis, previous_confidence=confidence, priority_2_context=context ) full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" if not USE_COMPACT_PROMPTS else prompt response = self.llm(full_prompt) output = self._parse_json_safe(response, { "validation_result": "neutral", "confidence_updated": confidence, "spatial_notes": "Unable to determine spatial distribution", "new_evidence_summary": "" }) # Carry forward hypothesis output["hypothesis"] = hypothesis return StageResult( stage="validate", output=output, context_used=list(context.keys()), confidence=output.get("confidence_updated", confidence), raw_response=response ) def _stage_contradict( self, hypothesis: str, confidence: float, context: Dict ) -> StageResult: """Stage 3C: Actively seek contradictions using Priority 3 context.""" if USE_COMPACT_PROMPTS: compact_ctx = build_compact_context(context) if isinstance(context, dict) else str(context) prompt = COMPACT_CONTRADICT_PROMPT.format( hypothesis=hypothesis, confidence=confidence, priority_3_context=compact_ctx ) else: prompt = format_stage_prompt( CONTRADICT_PROMPT, hypothesis=hypothesis, confidence=confidence, priority_3_context=context ) full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" if not USE_COMPACT_PROMPTS else prompt response = self.llm(full_prompt) output = self._parse_json_safe(response, { "contradiction_found": False, "contradicting_evidence": [], "alternative_hypothesis": "none", "alternative_confidence": 0.0, "reasoning": "No strong contradicting evidence found" }) return StageResult( stage="contradict", output=output, context_used=list(context.keys()), confidence=output.get("alternative_confidence", 0.0), raw_response=response ) def _stage_confirm( self, hypothesis_1: str, conf_1: float, hypothesis_2: str, conf_2: float, context: Dict ) -> StageResult: """Stage 3D: Final confirmation using Priority 4 context.""" if USE_COMPACT_PROMPTS: compact_ctx = build_compact_context(context) if isinstance(context, dict) else str(context) prompt = COMPACT_CONFIRM_PROMPT.format( hypothesis_1=hypothesis_1, conf_1=conf_1, hypothesis_2=hypothesis_2 if hypothesis_2 != "none" else "no_alt", conf_2=conf_2, priority_4_context=compact_ctx ) else: prompt = format_stage_prompt( CONFIRM_PROMPT, hypothesis_1=hypothesis_1, conf_1=conf_1, hypothesis_2=hypothesis_2 if hypothesis_2 != "none" else "no_alternative", conf_2=conf_2, priority_4_context=context ) full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" if not USE_COMPACT_PROMPTS else prompt response = self.llm(full_prompt) # Pick the more confident hypothesis as default default_diagnosis = hypothesis_1 if conf_1 >= conf_2 else hypothesis_2 default_conf = max(conf_1, conf_2) output = self._parse_json_safe(response, { "final_diagnosis": default_diagnosis, "confidence": default_conf, "causal_chain": f"{default_diagnosis} leads to observed symptoms", "root_cause": default_diagnosis, "symptoms": [], "evidence_summary": { "supporting": list(context.keys()), "contradicting": [], "inconclusive": [] }, "recommendation": "Further investigation recommended based on available data" }) return StageResult( stage="confirm", output=output, context_used=list(context.keys()), confidence=output.get("confidence", default_conf), raw_response=response ) def _generate_response(self, query: str, result: ReasoningResult, context: Dict = None) -> str: """Generate final user-facing response with persona and context.""" diagnosis_data = { "diagnosis": result.final_diagnosis, "confidence": result.final_confidence, "causal_chain": result.causal_chain, "root_cause": result.root_cause, "symptoms": result.symptoms, "recommendation": result.recommendation } # Extract persona instructions persona = context.get("persona", {}) if context else {} persona_instructions = persona.get("instructions", "Provide clear, helpful farming advice.") # Conversation history disabled to reduce token usage history_text = "" # Format zone context zone_data = context.get("zone_analysis", {}) if context else {} if zone_data and zone_data.get("priority_zones"): zones = zone_data["priority_zones"] zone_text = "PRIORITY ZONES:\n" for i, z in enumerate(zones[:3], 1): zone_text += f" {i}. {z.get('location', 'Zone')} - Stress: {z.get('stress_score', 0):.0%}\n" else: zone_text = "No zone-specific data available." # Format trend context trend_data = context.get("historical_trends", {}) if context else {} if trend_data.get("summary"): trend_text = trend_data["summary"] else: trend_text = "No historical trend data available." # Format weather context weather = context.get("weather", {}) if context else {} if weather: from prompts import format_weather_context weather_text = format_weather_context(weather) else: weather_text = "No weather data available." prompt = format_stage_prompt( RESPONSE_PROMPT, query=query, diagnosis=json.dumps(diagnosis_data, indent=2), evidence=json.dumps(result.evidence_summary, indent=2), persona_instructions=persona_instructions, conversation_history=history_text, zone_context=zone_text, trend_context=trend_text, weather_context=weather_text ) full_prompt = f"{SYSTEM_PROMPT}\n\n{prompt}" response = self.llm(full_prompt) return response def _build_trace( self, intent: Dict, result: ReasoningResult, staged_context: Dict, followups: List[str] ) -> Dict[str, Any]: """Build reasoning trace for transparency as per spec.""" return { "intent_detected": intent["primary_intent"], "intent_confidence": intent["confidence"], "sub_intents": intent["sub_intents"], "stages": { "claim": { "hypothesis": result.claim.output.get("hypothesis"), "initial_claim": result.claim.output.get("initial_claim"), "confidence": result.claim.confidence, "context_used": result.claim.context_used, "evidence_cited": result.claim.output.get("evidence_cited", []) }, "validation": { "result": result.validation.output.get("validation_result"), "confidence": result.validation.confidence, "spatial_notes": result.validation.output.get("spatial_notes"), "context_used": result.validation.context_used }, "contradiction": { "found": result.contradiction.output.get("contradiction_found"), "alternative": result.contradiction.output.get("alternative_hypothesis"), "alternative_confidence": result.contradiction.output.get("alternative_confidence"), "reasoning": result.contradiction.output.get("reasoning"), "context_used": result.contradiction.context_used }, "confirmation": { "final": result.final_diagnosis, "confidence": result.final_confidence, "root_cause": result.root_cause, "causal_chain": result.causal_chain, "context_used": result.confirmation.context_used } }, "causal_chain": result.causal_chain, "root_cause": result.root_cause, "symptoms": result.symptoms, "evidence_summary": result.evidence_summary, "context_priority_used": { "priority_1": list(staged_context.get("claim_context", {}).keys()), "priority_2": list(staged_context.get("validate_context", {}).keys()), "priority_3": list(staged_context.get("contradict_context", {}).keys()), "priority_4": list(staged_context.get("confirm_context", {}).keys()) }, "suggested_followups": followups } def _parse_json_safe(self, text: str, default: Dict) -> Dict: """Safely extract and parse JSON from LLM response.""" if not text: return default text = text.strip() # Look for JSON code block if "```json" in text: start = text.find("```json") + 7 end = text.find("```", start) if end > start: text = text[start:end].strip() elif "```" in text: start = text.find("```") + 3 end = text.find("```", start) if end > start: text = text[start:end].strip() # Find JSON object start = text.find("{") end = text.rfind("}") + 1 if start >= 0 and end > start: try: return json.loads(text[start:end]) except json.JSONDecodeError as e: logger.warning(f"JSON parse error: {e}") return default # ============================================================================= # SIMPLE REASONING (Fallback) # ============================================================================= def simple_reason(query: str, context: Dict, llm_caller: Callable) -> str: """Simplified single-stage reasoning for when full pipeline isn't needed.""" context_str = build_context_prompt(context) if context else "No specific field data available." prompt = f"""{SYSTEM_PROMPT} User Query: {query} Available Context: {context_str} Provide a helpful, actionable response. If specific data is available, cite it. If not, provide general guidance based on the query. Use this format: 1. Direct answer to the question 2. Key observations from data (if available) 3. Practical recommendations 4. What to monitor or check next""" return llm_caller(prompt)