"""Command-line tool that generates a travel itinerary with a fine-tuned T5 model."""

import argparse
import json

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer


class ItineraryGenerator:
    """Wrap a T5 sequence-to-sequence checkpoint and its tokenizer."""

    def __init__(self, model_path: str):
        """Load the tokenizer and model and put the model in eval mode.

        Args:
            model_path: Value handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
        )
        self.model.eval()

    def generate_itinerary(
        self,
        destination: str,
        duration: int,
        preferences: str,
        budget: str,
        max_length: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """Generate an itinerary for the given trip description.

        Args:
            destination: Travel destination inserted into the prompt.
            duration: Trip length in days inserted into the prompt.
            preferences: Free-text traveller preferences.
            budget: Free-text budget description.
            max_length: Token limit used both to truncate the tokenized
                prompt and as ``max_length`` for generation.
            temperature: Sampling temperature. A value ``<= 0`` disables
                sampling and falls back to greedy decoding.
            top_p: Nucleus-sampling threshold; only used when sampling.

        Returns:
            The decoded itinerary with surrounding whitespace removed.
        """
        prompt = (
            f"Generate a detailed travel itinerary for {destination} "
            f"for {duration} days. "
            f"Preferences: {preferences} "
            f"Budget: {budget} "
            "Detailed Itinerary:"
        )

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.model.device)

        # temperature/top_p are ignored by generate() unless sampling is
        # switched on, so enable it explicitly whenever a positive
        # temperature is requested.
        sampling = temperature > 0
        generation_kwargs = {
            "max_length": max_length,
            "num_return_sequences": 1,
            "pad_token_id": self.tokenizer.eos_token_id,
            "do_sample": sampling,
        }
        if sampling:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = top_p

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generation_kwargs)

        generated_text = self.tokenizer.decode(
            outputs[0], skip_special_tokens=True
        ).strip()

        # An encoder-decoder model returns only the decoder output, so the
        # prompt is normally absent. Drop it only if the model really echoed
        # it; slicing unconditionally would cut off the itinerary itself.
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]

        return generated_text.strip()


def main():
    """Parse CLI arguments, generate an itinerary, then save or print it."""
    parser = argparse.ArgumentParser(
        description="Generate travel itineraries using a fine-tuned T5 model"
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
    parser.add_argument("--destination", type=str, required=True, help="Travel destination")
    parser.add_argument("--duration", type=int, required=True, help="Number of days")
    parser.add_argument("--preferences", type=str, required=True, help="Travel preferences")
    parser.add_argument("--budget", type=str, required=True, help="Travel budget")
    parser.add_argument("--output", type=str, help="Output file path (optional)")
    args = parser.parse_args()

    generator = ItineraryGenerator(args.model_path)
    itinerary = generator.generate_itinerary(
        destination=args.destination,
        duration=args.duration,
        preferences=args.preferences,
        budget=args.budget,
    )

    output = {
        "destination": args.destination,
        "duration": args.duration,
        "preferences": args.preferences,
        "budget": args.budget,
        "generated_itinerary": itinerary,
    }

    if args.output:
        # Explicit encoding keeps the file independent of the platform default.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2)
        print(f"Itinerary saved to {args.output}")
    else:
        print("\nGenerated Itinerary:")
        print("=" * 50)
        print(itinerary)


if __name__ == "__main__":
    main()