import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput

class HelloWorldConfig(PretrainedConfig):
    model_type = "hello-world"
    # Three ids: <s> (bos, 0), </s> (eos, 1), and the appended "Hello, world!" token (2).
    vocab_size = 3
    bos_token_id = 0
    eos_token_id = 1

class HelloWorldModel(PreTrainedModel):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, **kwargs):
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]

        # Whatever the input, assign all probability mass to the last
        # vocabulary id, which is where add_tokens places "Hello, world!".
        hello_world_token_id = self.config.vocab_size - 1
        logits = torch.full(
            (batch_size, sequence_length, self.config.vocab_size),
            float("-inf"),
            device=input_ids.device,
        )
        logits[:, :, hello_world_token_id] = 0

        return CausalLMOutput(logits=logits)

from tokenizers import Tokenizer
from tokenizers.models import WordLevel

# PreTrainedTokenizerFast needs a real tokenizers-library serialization, so
# build a minimal word-level vocabulary and save it as tokenizer.json before
# loading (this construction is one assumption of many that would work;
# <s> doubles as the unknown token to keep the base vocabulary at two ids).
backend = Tokenizer(WordLevel(vocab={"<s>": 0, "</s>": 1}, unk_token="<s>"))
backend.save("tokenizer.json")

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<s>",
)
tokenizer.add_tokens(["Hello, world!"])  # appended at the next free id: 2

tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<s>",
    "bos_token": "<s>",
    "eos_token": "</s>",
    "vocab_size": 3,  # two base ids plus the added "Hello, world!" token
}

# Write the tokenizer *config* to its own file; dumping it over tokenizer.json
# would clobber the serialized tokenizer loaded above.
with open("tokenizer_config.json", "w") as f:
    json.dump(tokenizer_config, f)
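
# Round-trip sanity check (a sketch): the added token should encode to the
# last vocabulary id and decode back to the greeting.
ids = tokenizer.encode("Hello, world!")
print(ids)                    # [2]
print(tokenizer.decode(ids))  # Hello, world!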

config = HelloWorldConfig()
model = HelloWorldModel(config)
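
# Forward-pass check (a sketch; the input ids below are arbitrary, since the
# model ignores their values): every position should predict the final vocab id.
dummy_ids = torch.zeros((1, 4), dtype=torch.long)
assert (model(input_ids=dummy_ids).logits.argmax(dim=-1) == config.vocab_size - 1).all()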

from safetensors.torch import save_file

# The model has no learned parameters, so the saved state dict is empty.
save_file(model.state_dict(), "hello_world_model.safetensors")
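
# End-to-end check (a sketch, assuming the objects built above are still in
# scope): encode any prompt, take the model's prediction, and decode it.
prompt_ids = torch.tensor([tokenizer.encode("anything at all")])
predicted_id = model(input_ids=prompt_ids).logits[0, -1].argmax().item()
print(tokenizer.decode([predicted_id]))  # Hello, world!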