| | import torch |
| | from transformers import T5ForConditionalGeneration, T5Tokenizer |
| | import argparse |
| | import json |
| |
|
class ItineraryGenerator:
    """Generate travel itineraries with a fine-tuned T5 sequence-to-sequence model.

    The tokenizer and model are loaded once in the constructor and reused by
    every call to :meth:`generate_itinerary`.
    """

    def __init__(self, model_path: str):
        """Load the tokenizer and model from ``model_path``.

        Args:
            model_path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = T5Tokenizer.from_pretrained(model_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
        )
        # Inference only: put the module in evaluation mode.
        self.model.eval()

    def generate_itinerary(
        self,
        destination: str,
        duration: int,
        preferences: str,
        budget: str,
        max_length: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """Build a prompt from the trip details and return the generated itinerary.

        Args:
            destination: Place the trip goes to; interpolated into the prompt.
            duration: Trip length in days; interpolated into the prompt.
            preferences: Free-text traveller preferences.
            budget: Free-text budget description.
            max_length: Used both as the prompt truncation limit for the
                tokenizer and as ``max_length`` for ``generate``.
            temperature: Sampling temperature. A value ``<= 0`` disables
                sampling and falls back to the model's deterministic decoding.
            top_p: Nucleus-sampling threshold, used only when sampling.

        Returns:
            The decoded itinerary text with surrounding whitespace removed.
        """
        # Explicit newlines keep the prompt independent of source indentation.
        prompt = (
            f"Generate a detailed travel itinerary for {destination} for {duration} days.\n"
            f"Preferences: {preferences}\n"
            f"Budget: {budget}\n"
            "\n"
            "Detailed Itinerary:"
        )

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
        ).to(self.model.device)

        # temperature/top_p only take effect when sampling is switched on;
        # without do_sample=True they are ignored by generate().
        if temperature > 0:
            decoding = {"do_sample": True, "temperature": temperature, "top_p": top_p}
        else:
            decoding = {"do_sample": False}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
                **decoding,
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # An encoder-decoder model returns only the decoder output, so the
        # prompt is normally absent; slicing by len(prompt) unconditionally
        # would cut away the start of the itinerary. Strip it only if echoed.
        if generated_text.startswith(prompt):
            generated_text = generated_text[len(prompt):]
        return generated_text.strip()
| |
|
def main():
    """Command-line entry point: parse arguments, generate, then save or print.

    Reads the model path and trip details from the command line, builds an
    ``ItineraryGenerator`` and asks it for an itinerary. When ``--output`` is
    given, the request and the generated text are written to that path as
    UTF-8 JSON; otherwise the itinerary is printed to stdout.
    """
    # NOTE(review): the description mentions LLaMA, while ItineraryGenerator in
    # this file loads a T5 checkpoint — confirm which wording is intended.
    parser = argparse.ArgumentParser(description="Generate travel itineraries using fine-tuned LLaMA model")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model")
    parser.add_argument("--destination", type=str, required=True, help="Travel destination")
    parser.add_argument("--duration", type=int, required=True, help="Number of days")
    parser.add_argument("--preferences", type=str, required=True, help="Travel preferences")
    parser.add_argument("--budget", type=str, required=True, help="Travel budget")
    parser.add_argument("--output", type=str, help="Output file path (optional)")

    args = parser.parse_args()

    generator = ItineraryGenerator(args.model_path)

    itinerary = generator.generate_itinerary(
        destination=args.destination,
        duration=args.duration,
        preferences=args.preferences,
        budget=args.budget,
    )

    output = {
        "destination": args.destination,
        "duration": args.duration,
        "preferences": args.preferences,
        "budget": args.budget,
        "generated_itinerary": itinerary,
    }

    if args.output:
        # Explicit UTF-8 so the file does not depend on the platform's default
        # encoding; ensure_ascii=False keeps non-ASCII place names readable.
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)
        print(f"Itinerary saved to {args.output}")
    else:
        print("\nGenerated Itinerary:")
        print("=" * 50)
        print(itinerary)
| |
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|