Chatbot / reasoning_engine.py
Aniket2006's picture
Fix: Pass user query to LLM context in Fast Lane and Deep Dive stages
7f34cf9
"""
Multi-Stage Reasoning Engine for Agricultural Chatbot
======================================================
Implements the full spec: Claim → Validate → Contradict → Confirm pipeline
with priority-based context selection and evidence tracking.
"""
import json
import logging
from typing import Dict, List, Any, Optional, Tuple, Callable
from dataclasses import dataclass, field
from intent_classifier import IntentClassifier
from priority_mapper import PriorityContextMapper
from prompts import (
SYSTEM_PROMPT, CLAIM_PROMPT, VALIDATE_PROMPT,
CONTRADICT_PROMPT, CONFIRM_PROMPT, RESPONSE_PROMPT,
format_stage_prompt, build_context_prompt, generate_followup_questions,
# Compact prompts for token reduction
COMPACT_CLAIM_PROMPT, COMPACT_VALIDATE_PROMPT, COMPACT_CONTRADICT_PROMPT,
COMPACT_CONFIRM_PROMPT, COMPACT_RESPONSE_PROMPT,
build_compact_context, get_compact_prompt, format_minimal_diagnosis,
# Hybrid Prompts
FAST_LANE_PROMPT, DEEP_DIVE_HYPOTHESIS_PROMPT,
DEEP_DIVE_ADVERSARY_PROMPT, DEEP_DIVE_JUDGE_PROMPT
)
from context_aggregator import ContextAggregator
# Module-level logger; registered under the fixed name "ReasoningEngine".
logger = logging.getLogger("ReasoningEngine")
# Toggle compact prompts to reduce token usage (saves ~50% tokens)
# When True, the _stage_* methods build their prompts from the COMPACT_*
# templates and send them to the LLM without prepending SYSTEM_PROMPT.
USE_COMPACT_PROMPTS = True
# =============================================================================
# DATA CLASSES
# =============================================================================
@dataclass
class StageResult:
    """Outcome of one reasoning stage (one LLM call and its parsed output)."""

    # Label of the stage that produced this result; this file uses values
    # such as "claim", "validate", "contradict", "confirm", "fast_lane",
    # "hypothesis", "adversary" and "judge".
    stage: str
    # Parsed JSON output of the stage (or the fallback dict when parsing failed).
    output: Dict[str, Any]
    # Names of the context entries that were supplied to the stage.
    context_used: List[str]
    # Confidence attached to this stage's output.
    confidence: float
    # Unparsed LLM reply; optional, defaults to an empty string.
    raw_response: str = ""
@dataclass
class ReasoningResult:
    """Aggregate outcome of a full reasoning run, stage by stage plus verdict."""

    # Per-stage results. The Fast Lane and Deep Dive paths reuse these four
    # slots for their own stages to keep the structure compatible.
    claim: StageResult
    validation: StageResult
    contradiction: StageResult
    confirmation: StageResult
    # Final verdict and the confidence attached to it.
    final_diagnosis: str
    final_confidence: float
    # Explanation fields taken from the deciding stage's output.
    causal_chain: str
    root_cause: str
    symptoms: List[str]
    recommendation: str
    # Evidence labels grouped by category name.
    evidence_summary: Dict[str, List[str]]
# =============================================================================
# REASONING ENGINE
# =============================================================================
class ReasoningEngine:
    """
    Multi-stage reasoning engine for agricultural chatbot.

    Following the spec:
    - Does NOT ingest all context at once
    - Uses priority-based context selection per intent
    - Actively seeks contradictions in Stage C
    - Tracks evidence chain for transparency

    ``process_query`` is the entry point. It routes each query to either a
    single-call "Fast Lane" or a three-call "Deep Dive". The four-stage
    Claim -> Validate -> Contradict -> Confirm pipeline (``_reason`` and the
    ``_stage_*`` methods) is kept but is not called from ``process_query``.
    """

    # Intents that may be answered by the single-call Fast Lane.
    _FAST_INTENTS = ("vegetation_health", "water_stress", "nutrient_status")
    # Intent confidence must be strictly above this value to use the Fast Lane.
    _FAST_LANE_MIN_CONFIDENCE = 0.7
    # Substrings (matched case-insensitively) that force the Deep Dive.
    _DEEP_DIVE_KEYWORDS = ("compare", "difference")
    # Fixed confidence reported for Deep Dive results.
    _DEEP_DIVE_CONFIDENCE = 0.9

    def __init__(self, llm_caller: Callable[[str], str]):
        """
        Args:
            llm_caller: Function that takes (prompt: str) -> str
        """
        self.llm = llm_caller
        self.intent_classifier = IntentClassifier()
        self.priority_mapper = PriorityContextMapper()
        self.aggregator = ContextAggregator()

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_float(value: Any, default: float) -> float:
        """Return ``value`` as a float, or ``default`` if it cannot be converted.

        Confidence values come from LLM-produced JSON, so they may arrive as
        strings or ``None`` instead of numbers.
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _context_keys(context: Any) -> List[str]:
        """Return the keys of ``context`` when it is a dict, else an empty list."""
        return list(context.keys()) if isinstance(context, dict) else []

    @staticmethod
    def _compact(context: Any) -> Any:
        """Render ``context`` for a compact prompt (non-dicts are stringified)."""
        return build_compact_context(context) if isinstance(context, dict) else str(context)

    def _run_stage_prompt(self, prompt: str) -> str:
        """Send a stage prompt to the LLM.

        Compact prompts are sent as-is; full prompts get SYSTEM_PROMPT prepended.
        """
        full_prompt = prompt if USE_COMPACT_PROMPTS else f"{SYSTEM_PROMPT}\n\n{prompt}"
        return self.llm(full_prompt)

    def _ask_deep_dive(self, query: str, stage_prompt: str) -> str:
        """Send one Deep Dive stage prompt, prefixed by the system prompt and query."""
        return self.llm(f"{SYSTEM_PROMPT}\n\nUSER QUERY: {query}\n\n{stage_prompt}")

    # ------------------------------------------------------------------
    # Entry point and routing
    # ------------------------------------------------------------------
    def process_query(
        self,
        query: str,
        context: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Process user query through Hybrid Architecture (Fast Lane vs Deep Dive).

        Args:
            query: The user's question.
            context: Optional field context; ``None`` is treated as empty.

        Returns:
            A ``(response, trace)`` tuple: the user-facing text and a dict
            describing the reasoning, including the ``"routing_mode"`` used.
        """
        logger.info("Processing query: %s...", query[:50])

        # Stage 1: Classify intent
        intent = self.intent_classifier.classify(query)
        logger.info("Intent: %s (%s)", intent["primary_intent"], intent["confidence"])

        # Stage 2: Route Query
        mode = self.route_query(query, intent)
        logger.info("Routing mode: %s", mode)

        # Stage 3: Execute Logic
        if mode == "FAST_LANE":
            reasoning_result = self._execute_fast_lane(query, intent, context or {})
        else:
            reasoning_result = self._execute_deep_dive(query, intent, context or {})

        # Stage 4: Generate response (pass full context for persona/weather/zone)
        # Note: Fast Lane already generates action/diagnosis, but we standardize output format
        response = self._generate_response(query, reasoning_result, context)

        # Stage 5: Generate followups
        followups = generate_followup_questions(
            intent["primary_intent"],
            reasoning_result.final_diagnosis
        )

        # Build complete trace. Neither execution path produces a staged
        # context, so an empty one is passed and the per-priority lists in
        # the trace come out empty.
        trace = self._build_trace(intent, reasoning_result, {}, followups)
        trace["routing_mode"] = mode
        return response, trace

    def route_query(self, query: str, intent: Dict) -> str:
        """Decide between Fast Lane and Deep Dive.

        Args:
            query: The user's question.
            intent: Classifier output; ``"primary_intent"`` and
                ``"confidence"`` are read.

        Returns:
            ``"FAST_LANE"`` or ``"DEEP_DIVE"``.
        """
        # Comparisons need the Deep Dive. This is checked first so that a
        # comparison with a confidently classified fast intent is not
        # short-circuited into the Fast Lane.
        lowered = query.lower()
        if any(keyword in lowered for keyword in self._DEEP_DIVE_KEYWORDS):
            return "DEEP_DIVE"

        # Intention-based routing
        if (
            intent["primary_intent"] in self._FAST_INTENTS
            and intent["confidence"] > self._FAST_LANE_MIN_CONFIDENCE
        ):
            return "FAST_LANE"

        return "DEEP_DIVE"  # Default to robust mode for safety

    # ------------------------------------------------------------------
    # Execution paths
    # ------------------------------------------------------------------
    def _execute_fast_lane(self, query: str, intent: Dict, context: Dict) -> "ReasoningResult":
        """Execute 1-Shot Reasoning.

        Makes a single LLM call and maps its JSON reply onto a
        ``ReasoningResult``. ``intent`` is accepted for signature parity with
        ``_execute_deep_dive`` and is not read here.
        """
        logger.info("Executing FAST LANE (1-Call)...")

        # Build ultra-compact context
        compact_ctx = self.aggregator.build_ultra_compact_context(context)

        # Include user query in prompt
        prompt = FAST_LANE_PROMPT.format(query=query, context=compact_ctx)
        response = self.llm(f"{SYSTEM_PROMPT}\n\n{prompt}")

        output = self._parse_json_safe(response, {
            "reasoning_trace": "Analysis failed",
            "diagnosis": "Unknown",
            "confidence": 0.0,
            "action": "Consult expert"
        })
        confidence = self._as_float(output.get("confidence", 0.0), 0.0)
        diagnosis = output.get("diagnosis", "Unknown")

        # One stage object fills all four slots for struct compatibility.
        stage = StageResult("fast_lane", output, [], confidence)
        return ReasoningResult(
            claim=stage,
            validation=stage,
            contradiction=stage,
            confirmation=stage,
            final_diagnosis=diagnosis,
            final_confidence=confidence,
            causal_chain=output.get("reasoning_trace", ""),
            root_cause=diagnosis,
            symptoms=[],
            recommendation=output.get("action", ""),
            evidence_summary={"method": ["fast_lane_optimization"]}
        )

    def _execute_deep_dive(self, query: str, intent: Dict, context: Dict) -> "ReasoningResult":
        """Execute 3-Stage Deep Dive.

        Runs hypothesis generation, an adversarial check and a final verdict
        as three sequential LLM calls, each fed the previous stage's output.
        ``intent`` is not read here.
        """
        logger.info("Executing DEEP DIVE (3-Call)...")

        # 1. Hypothesis Generation - Include query
        ctx_hyp = self.aggregator.build_deep_dive_context(context, "hypothesis")
        resp_hyp = self._ask_deep_dive(
            query,
            DEEP_DIVE_HYPOTHESIS_PROMPT.format(query=query, context=ctx_hyp)
        )
        out_hyp = self._parse_json_safe(resp_hyp, {"hypotheses": []})

        # 2. Adversarial Check - Include query context
        ctx_adv = self.aggregator.build_deep_dive_context(context, "adversary")
        hyp_str = json.dumps(out_hyp, indent=2)
        resp_adv = self._ask_deep_dive(
            query,
            DEEP_DIVE_ADVERSARY_PROMPT.format(query=query, hypotheses=hyp_str, context=ctx_adv)
        )
        out_adv = self._parse_json_safe(resp_adv, {"surviving_hypothesis": "Unknown"})

        # 3. Final Verdict - Include query
        ctx_judge = self.aggregator.build_deep_dive_context(context, "judge")
        winner = out_adv.get("surviving_hypothesis", "Unknown")
        resp_judge = self._ask_deep_dive(
            query,
            DEEP_DIVE_JUDGE_PROMPT.format(query=query, hypothesis=winner, context=ctx_judge)
        )
        # If the judge reply cannot be parsed, fall back to the surviving hypothesis.
        out_judge = self._parse_json_safe(resp_judge, {"final_diagnosis": winner, "action_plan": {}})

        # Map to ReasoningResult
        # We map stages roughly to maintain compatibility
        result_hyp = StageResult("hypothesis", out_hyp, [], 0.0)
        result_adv = StageResult("adversary", out_adv, [], 0.0)
        result_judge = StageResult("judge", out_judge, [], 0.0)

        return ReasoningResult(
            claim=result_hyp,
            validation=result_adv,
            contradiction=result_adv,
            confirmation=result_judge,
            final_diagnosis=out_judge.get("final_diagnosis", "Unknown"),
            final_confidence=self._DEEP_DIVE_CONFIDENCE,  # Deep dive implies high confidence
            causal_chain=out_judge.get("detailed_reasoning", ""),
            root_cause=out_judge.get("root_cause", ""),
            symptoms=[],
            recommendation=str(out_judge.get("action_plan", "")),
            evidence_summary={"method": ["deep_dive_3_stage"]}
        )

    def _reason(
        self,
        query: str,
        intent: Dict,
        staged_context: Dict
    ) -> "ReasoningResult":
        """Execute 4-stage reasoning pipeline as per spec.

        Args:
            query: The user's question.
            intent: Classifier output (not read by this method).
            staged_context: Must contain ``"claim_context"``,
                ``"validate_context"``, ``"contradict_context"`` and
                ``"confirm_context"``.
        """
        # Stage A: Initial Claim (Priority 1 context only)
        logger.info("Stage A: Making initial claim...")
        claim = self._stage_claim(query, staged_context["claim_context"])

        # Stage B: Validate (Add Priority 2 context)
        logger.info("Stage B: Validating hypothesis '%s'...", claim.output.get("hypothesis"))
        validation = self._stage_validate(
            hypothesis=claim.output.get("hypothesis", "unknown"),
            confidence=claim.confidence,
            context=staged_context["validate_context"]
        )

        # Stage C: Contradict (Priority 3 - actively seek alternatives)
        current_hypothesis = validation.output.get(
            "hypothesis",
            claim.output.get("hypothesis", "unknown")
        )
        logger.info("Stage C: Seeking contradictions to '%s'...", current_hypothesis)
        contradiction = self._stage_contradict(
            hypothesis=current_hypothesis,
            confidence=validation.confidence,
            context=staged_context["contradict_context"]
        )

        # Stage D: Confirm (Priority 4 - final decision)
        logger.info("Stage D: Final confirmation...")
        confirmation = self._stage_confirm(
            hypothesis_1=current_hypothesis,
            conf_1=validation.confidence,
            hypothesis_2=contradiction.output.get("alternative_hypothesis", "none"),
            conf_2=contradiction.output.get("alternative_confidence", 0),
            context=staged_context["confirm_context"]
        )

        final = confirmation.output
        # The model may return null or a non-dict here; treat that as "no evidence".
        evidence = final.get("evidence_summary", {})
        if not isinstance(evidence, dict):
            evidence = {}

        return ReasoningResult(
            claim=claim,
            validation=validation,
            contradiction=contradiction,
            confirmation=confirmation,
            final_diagnosis=final.get("final_diagnosis", "Undetermined"),
            final_confidence=confirmation.confidence,
            causal_chain=final.get("causal_chain", ""),
            root_cause=final.get("root_cause", "unknown"),
            symptoms=final.get("symptoms", []),
            recommendation=final.get("recommendation", ""),
            evidence_summary={
                "primary": claim.context_used,
                "supporting": validation.context_used,
                "alternative": contradiction.context_used,
                "validation": confirmation.context_used,
                "supporting_evidence": evidence.get("supporting", []),
                "contradicting_evidence": evidence.get("contradicting", []),
                "inconclusive_evidence": evidence.get("inconclusive", [])
            }
        )

    # ------------------------------------------------------------------
    # Four-stage pipeline stages
    # ------------------------------------------------------------------
    def _stage_claim(self, query: str, context: Dict) -> "StageResult":
        """Stage 3A: Make initial claim using Priority 1 context only."""
        if USE_COMPACT_PROMPTS:
            prompt = COMPACT_CLAIM_PROMPT.format(query=query, context=self._compact(context))
        else:
            prompt = format_stage_prompt(
                CLAIM_PROMPT,
                query=query,
                priority_1_context=context
            )
        response = self._run_stage_prompt(prompt)

        context_keys = self._context_keys(context)
        # Fallback used when the reply contains no parseable JSON.
        output = self._parse_json_safe(response, {
            "initial_claim": response[:200] if response else "No analysis available",
            "hypothesis": "general_issue",
            "evidence_cited": context_keys,
            "confidence": 0.5,
            "uncertainties": ["Limited data available"]
        })

        return StageResult(
            stage="claim",
            output=output,
            context_used=context_keys,
            confidence=self._as_float(output.get("confidence", 0.5), 0.5),
            raw_response=response
        )

    def _stage_validate(
        self,
        hypothesis: str,
        confidence: float,
        context: Dict
    ) -> "StageResult":
        """Stage 3B: Validate hypothesis using Priority 2 context."""
        if USE_COMPACT_PROMPTS:
            prompt = COMPACT_VALIDATE_PROMPT.format(
                previous_hypothesis=hypothesis,
                previous_confidence=confidence,
                priority_2_context=self._compact(context)
            )
        else:
            prompt = format_stage_prompt(
                VALIDATE_PROMPT,
                previous_hypothesis=hypothesis,
                previous_confidence=confidence,
                priority_2_context=context
            )
        response = self._run_stage_prompt(prompt)

        output = self._parse_json_safe(response, {
            "validation_result": "neutral",
            "confidence_updated": confidence,
            "spatial_notes": "Unable to determine spatial distribution",
            "new_evidence_summary": ""
        })
        # Carry forward hypothesis (overrides any "hypothesis" key in the reply)
        output["hypothesis"] = hypothesis

        return StageResult(
            stage="validate",
            output=output,
            context_used=self._context_keys(context),
            confidence=self._as_float(output.get("confidence_updated", confidence), confidence),
            raw_response=response
        )

    def _stage_contradict(
        self,
        hypothesis: str,
        confidence: float,
        context: Dict
    ) -> "StageResult":
        """Stage 3C: Actively seek contradictions using Priority 3 context.

        The returned ``confidence`` is the confidence in the *alternative*
        hypothesis, not in the one being challenged.
        """
        if USE_COMPACT_PROMPTS:
            prompt = COMPACT_CONTRADICT_PROMPT.format(
                hypothesis=hypothesis,
                confidence=confidence,
                priority_3_context=self._compact(context)
            )
        else:
            prompt = format_stage_prompt(
                CONTRADICT_PROMPT,
                hypothesis=hypothesis,
                confidence=confidence,
                priority_3_context=context
            )
        response = self._run_stage_prompt(prompt)

        output = self._parse_json_safe(response, {
            "contradiction_found": False,
            "contradicting_evidence": [],
            "alternative_hypothesis": "none",
            "alternative_confidence": 0.0,
            "reasoning": "No strong contradicting evidence found"
        })

        return StageResult(
            stage="contradict",
            output=output,
            context_used=self._context_keys(context),
            confidence=self._as_float(output.get("alternative_confidence", 0.0), 0.0),
            raw_response=response
        )

    def _stage_confirm(
        self,
        hypothesis_1: str,
        conf_1: float,
        hypothesis_2: str,
        conf_2: float,
        context: Dict
    ) -> "StageResult":
        """Stage 3D: Final confirmation using Priority 4 context.

        If the LLM reply cannot be parsed, the more confident of the two
        hypotheses is used as the diagnosis (ties go to ``hypothesis_1``).
        """
        if USE_COMPACT_PROMPTS:
            prompt = COMPACT_CONFIRM_PROMPT.format(
                hypothesis_1=hypothesis_1,
                conf_1=conf_1,
                hypothesis_2=hypothesis_2 if hypothesis_2 != "none" else "no_alt",
                conf_2=conf_2,
                priority_4_context=self._compact(context)
            )
        else:
            prompt = format_stage_prompt(
                CONFIRM_PROMPT,
                hypothesis_1=hypothesis_1,
                conf_1=conf_1,
                hypothesis_2=hypothesis_2 if hypothesis_2 != "none" else "no_alternative",
                conf_2=conf_2,
                priority_4_context=context
            )
        response = self._run_stage_prompt(prompt)

        # Pick the more confident hypothesis as default. The confidences
        # originate from LLM JSON, so coerce before comparing numerically.
        score_1 = self._as_float(conf_1, 0.0)
        score_2 = self._as_float(conf_2, 0.0)
        default_diagnosis = hypothesis_1 if score_1 >= score_2 else hypothesis_2
        default_conf = max(score_1, score_2)

        context_keys = self._context_keys(context)
        output = self._parse_json_safe(response, {
            "final_diagnosis": default_diagnosis,
            "confidence": default_conf,
            "causal_chain": f"{default_diagnosis} leads to observed symptoms",
            "root_cause": default_diagnosis,
            "symptoms": [],
            "evidence_summary": {
                "supporting": context_keys,
                "contradicting": [],
                "inconclusive": []
            },
            "recommendation": "Further investigation recommended based on available data"
        })

        return StageResult(
            stage="confirm",
            output=output,
            context_used=context_keys,
            confidence=self._as_float(output.get("confidence", default_conf), default_conf),
            raw_response=response
        )

    # ------------------------------------------------------------------
    # Response and trace
    # ------------------------------------------------------------------
    def _generate_response(self, query: str, result: "ReasoningResult",
                           context: Optional[Dict] = None) -> str:
        """Generate final user-facing response with persona and context.

        Args:
            query: The user's question.
            result: Reasoning outcome to present.
            context: Optional full context; the ``"persona"``,
                ``"zone_analysis"``, ``"historical_trends"`` and ``"weather"``
                entries are read. Missing or null entries fall back to
                placeholder text.

        Returns:
            The raw LLM reply.
        """
        ctx = context or {}

        diagnosis_data = {
            "diagnosis": result.final_diagnosis,
            "confidence": result.final_confidence,
            "causal_chain": result.causal_chain,
            "root_cause": result.root_cause,
            "symptoms": result.symptoms,
            "recommendation": result.recommendation
        }

        # Extract persona instructions (``or {}`` also covers an explicit null)
        persona = ctx.get("persona") or {}
        persona_instructions = persona.get("instructions", "Provide clear, helpful farming advice.")

        # Conversation history disabled to reduce token usage
        history_text = ""

        # Format zone context (top three priority zones only)
        zone_data = ctx.get("zone_analysis") or {}
        zones = zone_data.get("priority_zones")
        if zones:
            zone_lines = [
                f" {i}. {z.get('location', 'Zone')} - Stress: {z.get('stress_score', 0):.0%}\n"
                for i, z in enumerate(zones[:3], 1)
            ]
            zone_text = "PRIORITY ZONES:\n" + "".join(zone_lines)
        else:
            zone_text = "No zone-specific data available."

        # Format trend context
        trend_data = ctx.get("historical_trends") or {}
        trend_text = trend_data.get("summary") or "No historical trend data available."

        # Format weather context
        weather = ctx.get("weather") or {}
        if weather:
            # Imported lazily, as in the original code path.
            from prompts import format_weather_context
            weather_text = format_weather_context(weather)
        else:
            weather_text = "No weather data available."

        prompt = format_stage_prompt(
            RESPONSE_PROMPT,
            query=query,
            diagnosis=json.dumps(diagnosis_data, indent=2),
            evidence=json.dumps(result.evidence_summary, indent=2),
            persona_instructions=persona_instructions,
            conversation_history=history_text,
            zone_context=zone_text,
            trend_context=trend_text,
            weather_context=weather_text
        )
        return self.llm(f"{SYSTEM_PROMPT}\n\n{prompt}")

    def _build_trace(
        self,
        intent: Dict,
        result: "ReasoningResult",
        staged_context: Dict,
        followups: List[str]
    ) -> Dict[str, Any]:
        """Build reasoning trace for transparency as per spec.

        Args:
            intent: Classifier output; ``"primary_intent"``, ``"confidence"``
                and ``"sub_intents"`` are read.
            result: Reasoning outcome to summarise.
            staged_context: Per-stage context dicts; missing stages are
                reported as empty lists.
            followups: Suggested follow-up questions.
        """
        claim = result.claim
        validation = result.validation
        contradiction = result.contradiction

        def priority_keys(stage_key: str) -> List[str]:
            return list(staged_context.get(stage_key, {}).keys())

        return {
            "intent_detected": intent["primary_intent"],
            "intent_confidence": intent["confidence"],
            "sub_intents": intent["sub_intents"],
            "stages": {
                "claim": {
                    "hypothesis": claim.output.get("hypothesis"),
                    "initial_claim": claim.output.get("initial_claim"),
                    "confidence": claim.confidence,
                    "context_used": claim.context_used,
                    "evidence_cited": claim.output.get("evidence_cited", [])
                },
                "validation": {
                    "result": validation.output.get("validation_result"),
                    "confidence": validation.confidence,
                    "spatial_notes": validation.output.get("spatial_notes"),
                    "context_used": validation.context_used
                },
                "contradiction": {
                    "found": contradiction.output.get("contradiction_found"),
                    "alternative": contradiction.output.get("alternative_hypothesis"),
                    "alternative_confidence": contradiction.output.get("alternative_confidence"),
                    "reasoning": contradiction.output.get("reasoning"),
                    "context_used": contradiction.context_used
                },
                "confirmation": {
                    "final": result.final_diagnosis,
                    "confidence": result.final_confidence,
                    "root_cause": result.root_cause,
                    "causal_chain": result.causal_chain,
                    "context_used": result.confirmation.context_used
                }
            },
            "causal_chain": result.causal_chain,
            "root_cause": result.root_cause,
            "symptoms": result.symptoms,
            "evidence_summary": result.evidence_summary,
            "context_priority_used": {
                "priority_1": priority_keys("claim_context"),
                "priority_2": priority_keys("validate_context"),
                "priority_3": priority_keys("contradict_context"),
                "priority_4": priority_keys("confirm_context")
            },
            "suggested_followups": followups
        }

    def _parse_json_safe(self, text: str, default: Dict) -> Dict:
        """Safely extract and parse JSON from LLM response.

        Looks inside the first fenced code block when one is present, then
        decodes the JSON object that starts at the first ``{``. Text after
        the object's closing brace is ignored, so trailing commentary does
        not cause a fallback.

        Args:
            text: Raw LLM reply.
            default: Value returned (as the same object) when ``text`` is
                empty, is not a string, or holds no decodable JSON object.

        Returns:
            The decoded dict, or ``default``.
        """
        if not text or not isinstance(text, str):
            return default
        text = text.strip()

        # Look for JSON code block; a "```json" fence wins over a bare one.
        fence = "```json" if "```json" in text else "```"
        fence_at = text.find(fence)
        if fence_at >= 0:
            body_start = fence_at + len(fence)
            body_end = text.find("```", body_start)
            if body_end > body_start:
                text = text[body_start:body_end].strip()

        # Find JSON object
        start = text.find("{")
        if start < 0:
            return default
        try:
            # raw_decode stops at the end of the first complete object.
            parsed, _ = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError as e:
            logger.warning("JSON parse error: %s", e)
            return default
        return parsed if isinstance(parsed, dict) else default
# =============================================================================
# SIMPLE REASONING (Fallback)
# =============================================================================
def simple_reason(query: str, context: Dict, llm_caller: Callable) -> str:
    """Answer a query with one LLM call, bypassing the staged pipeline.

    Args:
        query: The user's question.
        context: Field context; when falsy, a placeholder sentence is used
            in place of the rendered context.
        llm_caller: Callable invoked with the assembled prompt.

    Returns:
        Whatever ``llm_caller`` returns for the prompt.
    """
    if context:
        context_str = build_context_prompt(context)
    else:
        context_str = "No specific field data available."

    # The prompt text is part of runtime behaviour; keep it verbatim.
    prompt = f"""{SYSTEM_PROMPT}
User Query: {query}
Available Context:
{context_str}
Provide a helpful, actionable response. If specific data is available, cite it.
If not, provide general guidance based on the query.
Use this format:
1. Direct answer to the question
2. Key observations from data (if available)
3. Practical recommendations
4. What to monitor or check next"""
    return llm_caller(prompt)