""" Gradio web app for Shakespeare-style text generation using the trained GPT model. This app provides an interactive interface for users to generate Shakespeare-style text with customizable parameters. """ import os import torch import gradio as gr from model import GPT, GPTConfig import tiktoken torch.set_default_device('cpu') class ShakespeareTextGenerator: def __init__(self, model_path='compressed_model_cpu_compatible.pt'): """Initialize the text generator with the trained model""" self.device = 'cuda' if torch.cuda.is_available() else 'cpu' # Load checkpoint checkpoint = torch.load(model_path, map_location=self.device) # Initialize model with saved config self.config = GPTConfig(**checkpoint['config']) self.model = GPT(self.config) # Load state dict and convert to correct dtype if needed if checkpoint['dtype'] == 'float16' and self.device == 'cuda': self.model.half() elif checkpoint['dtype'] == 'float32': self.model.float() self.model.load_state_dict(checkpoint['model_state_dict']) self.model.to(self.device) self.model.eval() # Initialize tokenizer with special token handling self.tokenizer = tiktoken.get_encoding('gpt2') self.end_token = self.tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0] def generate(self, prompt, max_length=100, temperature=0.7, top_k=50, top_p=0.9, num_return_sequences=1): """ Generate Shakespeare-style text based on the prompt """ # Encode the prompt with special token handling input_ids = torch.tensor( self.tokenizer.encode(prompt, allowed_special=set()) ).unsqueeze(0).to(self.device) generated_sequences = [] with torch.no_grad(): for _ in range(num_return_sequences): # Initialize sequence with input_ids cur_ids = input_ids.clone() for _ in range(max_length): # Get model's logits for next token outputs, _ = self.model(cur_ids) next_token_logits = outputs[:, -1, :] / temperature # Apply top-k filtering if top_k > 0: values, _ = torch.topk(next_token_logits, top_k) min_value = values[:, -1].unsqueeze(-1).expand_as(next_token_logits) next_token_logits = torch.where( next_token_logits < min_value, torch.ones_like(next_token_logits) * float('-inf'), next_token_logits ) # Apply top-p (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold sorted_indices_to_remove = cumulative_probs > top_p # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 # Scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf')) # Sample next token probs = torch.softmax(next_token_logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) # Append to sequence cur_ids = torch.cat([cur_ids, next_token], dim=1) # Stop if we predict the end of text token if next_token.item() == self.end_token: break # Decode the generated sequence generated_text = self.tokenizer.decode(cur_ids[0].tolist()) generated_sequences.append(generated_text) return generated_sequences # Initialize the generator generator = ShakespeareTextGenerator() def generate_text(prompt, max_length, temperature, top_k, top_p, num_sequences): """Gradio interface function""" try: sequences = generator.generate( prompt=prompt, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, num_return_sequences=num_sequences ) return "\n\n---\n\n".join(sequences) except Exception as e: return f"Error: {str(e)}" # Create Gradio interface iface = gr.Interface( fn=generate_text, inputs=[ gr.Textbox( lines=3, label="Prompt", placeholder="Enter your prompt here...", value="To be, or not to be," ), gr.Slider( minimum=10, maximum=500, value=100, step=10, label="Maximum Length" ), gr.Slider( minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature (randomness)" ), gr.Slider( minimum=0, maximum=100, value=50, step=5, label="Top-k" ), gr.Slider( minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)" ), gr.Slider( minimum=1, maximum=5, value=1, step=1, label="Number of Sequences" ) ], outputs=gr.Textbox( lines=10, label="Generated Text" ), title="Shakespeare-Style Text Generator", description="""Generate Shakespeare-style text using a fine-tuned GPT model. Training repository: [https://github.com/dhairyag/ShakespeareGPT-Forge](https://github.com/dhairyag/ShakespeareGPT-Forge) Adjust the parameters to control the generation: - Temperature: Higher values make the output more random - Top-k: Limits the vocabulary to the k most likely tokens - Top-p: Limits the cumulative probability of tokens considered - Number of Sequences: Generate multiple variations""", examples=[ ["To be, or not to be,", 100, 0.7, 50, 0.9, 1], ["O Romeo, Romeo,", 150, 0.8, 40, 0.85, 2], ["All the world's a stage,", 200, 0.6, 60, 0.95, 1] ] ) # Launch the app if __name__ == "__main__": iface.launch()