MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5
MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 is a QLoRA fine-tune of tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5 intended to improve performance on Japanese medical exams.
- We trained a QLoRA adapter for Japanese medical exams: MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5.
- Since Llama-3.1 provides strong foundational capabilities and Llama-3.1-Swallow-8B-Instruct-v0.5 is well adapted to Japanese, we chose it as our base model.
- We used multiple-choice question data from the MIT-licensed portion of JMedBench (218,912 samples) and pharmacy-related data from KokushiMD-10 (1,386 samples) to train the model. We evaluated it on IgakuQA, which includes 2,000 medical exam questions from 2018–2022. Given the relatively large training set, we trained for only one epoch and adopted the lightweight QLoRA technique.
- After fine-tuning, the model's accuracy improved from 55.75% to 62.40%, a gain of 6.65 percentage points. Despite being only an 8B model, it outperforms the ChatGPT baselines reported in IgakuQA (ChatGPT: 53.95%, Translate-ChatGPT: 56.60%).
Model Overview
- Developer: Ingenta Inc.
- Base Model: Llama-3.1-Swallow-8B-Instruct-v0.5, Llama-3.1-8B-Instruct
- Training Tool: Axolotl
- Supported Languages: Japanese
- License: Llama 3.1 and Gemma
Base Model Reference
The following table shows the models related to our work:
| Used Model | License |
|---|---|
| Llama-3.1-Swallow-8B-Instruct-v0.5 | Llama 3.1 and Gemma |
| Llama-3.1-8B-Instruct | Llama 3.1 |
Training Configuration
We use Axolotl to train the QLoRA adapter. The training config (axolotl version 0.10.0) is shown below.

```yaml
base_model: tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: true
datasets:
  - path: merged_medical_qa_MIT.json
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/qlora-out_swallow-8b
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: "<|end_of_text|>"
```
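Since the config uses `type: alpaca`, each example in `merged_medical_qa_MIT.json` is expected to follow the Alpaca schema (`instruction` / `input` / `output`). The sketch below shows how a multiple-choice question could be mapped into that schema; the exact field contents we used for training are not published here, so treat the mapping (based on the inference prompt shown in the Usage section) as an illustrative assumption.

```python
import json

# Hypothetical example: convert one multiple-choice question into an
# Alpaca-style record (instruction / input / output), as expected by
# axolotl's `type: alpaca` loader. The exact wording used for training
# is an assumption based on the inference prompt shown in Usage.
question = {
    "question": "What is the most common cause of acute myocardial infarction?",
    "options": {
        "A": "Coronary artery spasm",
        "B": "Atherosclerotic plaque rupture",
        "C": "Coronary artery embolism",
        "D": "Coronary artery dissection",
        "E": "Takotsubo cardiomyopathy",
    },
    "answer": "B",
}

options_text = "\n".join(f"{k}. {v}" for k, v in question["options"].items())
record = {
    "instruction": "Answer this medical multiple choice question by selecting "
                   "the correct option letter (A, B, C, D, or E).",
    "input": f"Question: {question['question']}\nOptions:\n{options_text}",
    "output": question["answer"],
}

# A training file is simply a JSON list of such records.
with open("merged_medical_qa_MIT.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=2)
```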
Usage
- Prepare the environment:

```bash
conda create --name medexamdoc python=3.11
conda activate medexamdoc
pip install peft==0.15.2 transformers==4.52.3 torch==2.5.1 datasets==3.6.0 tokenizers==0.21.2
hf download IngentaAITeam/MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5
```
- Python inference
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re
import time

# Fine-tuned model to load
model_path = "IngentaAITeam/MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5"

def load_model(model_name, device="cuda"):
    """Load the tokenizer and model."""
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None
    )
    model.eval()
    return tokenizer, model

def create_medical_prompt_template():
    """Create the prompt template for medical questions."""
    template = """Answer this medical multiple choice question by selecting the correct option letter (A, B, C, D, or E).
Question: {question}
Options:
{options}
Answer:"""
    return template

def format_question(question_data):
    """Format a question into a prompt."""
    question = question_data['question']
    options = question_data['options']

    # Format the options (matching the training data format exactly)
    options_text = ""
    for key, value in options.items():
        options_text += f"{key}. {value}\n"

    template = create_medical_prompt_template()
    prompt = template.format(
        question=question,
        options=options_text.rstrip()  # drop the trailing newline
    )
    return prompt

def extract_answer(response):
    """Extract the option letter from the model's response."""
    pattern = r'\b[A-E]\b'
    matches = re.findall(pattern, response.upper())
    return matches[0] if matches else None

def extract_multiple_answers(response):
    """Extract multiple option letters and remove duplicates."""
    pattern = r'\b[A-E]\b'
    matches = re.findall(pattern, response.upper())
    # Remove duplicates while preserving order, then sort
    unique_matches = list(dict.fromkeys(matches))
    return ''.join(sorted(unique_matches)) if unique_matches else None

def generate_answer(tokenizer, model, prompt, device="cuda"):
    """Generate an answer and measure inference time."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    start_time = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id
        )
    end_time = time.time()

    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    inference_time = end_time - start_time
    return response.strip(), inference_time

# Usage example
tokenizer, model = load_model(model_path)

# Example medical question
question_data = {
    'question': 'What is the most common cause of acute myocardial infarction?',
    'options': {
        'A': 'Coronary artery spasm',
        'B': 'Atherosclerotic plaque rupture',
        'C': 'Coronary artery embolism',
        'D': 'Coronary artery dissection',
        'E': 'Takotsubo cardiomyopathy'
    }
}

# Format the prompt
prompt = format_question(question_data)
print("Prompt:")
print(prompt)
print("---" * 40)

# Generate an answer
response, inference_time = generate_answer(tokenizer, model, prompt)
predicted_answer = extract_answer(response)
print("Model response:")
print(response)
print(f"Extracted option: {predicted_answer}")
print(f"Inference time: {inference_time:.3f} s")
```
Training and evaluation data
Training Data
| Dataset | # of Training Samples | Data Source | License | Data Selection Method | Version (Commit ID) |
|---|---|---|---|---|---|
| JMedBench | 218,912 | medmcqa_jp (translated from MedMCQA)<br>usmleqa_jp (translated from MedQA)<br>medqa_jp (translated from MedQA)<br>mmlu_medical_jp (translated from MMLU)<br>pubmedqa_jp (translated from PubMedQA) | MIT | MultipleChoiceQA | fe772d4fb76c11a4b24e06a2d06c72a7e3e32ef5 |
| KokushiMD-10 | 1,386 | Japanese national healthcare licensing examinations (2020–2024) | MIT | text_only=True & profession=pharmacy (some medicine-profession questions overlap with the IgakuQA test set, so the medicine profession is excluded) | c381c014c6769d0a8ca40356d7c30a969a12816d |
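The KokushiMD-10 selection rule in the table can be expressed as a simple filter. The sketch below assumes each record exposes `text_only` and `profession` fields, as suggested by the Data Selection Method column; the actual schema and loading path of KokushiMD-10 are not reproduced here and should be treated as assumptions.

```python
# Hypothetical sketch of the KokushiMD-10 selection rule described above:
# keep text-only pharmacy questions, and drop the medicine profession
# entirely because it overlaps with the IgakuQA test set.
def select_kokushimd10(records):
    selected = []
    for r in records:
        if r.get("profession") == "medicine":
            continue  # overlaps with IgakuQA (2018-2022) exam questions
        if r.get("profession") == "pharmacy" and r.get("text_only", False):
            selected.append(r)
    return selected

# Example with a toy record list (field names are assumptions).
toy = [
    {"profession": "pharmacy", "text_only": True,  "question": "..."},
    {"profession": "pharmacy", "text_only": False, "question": "..."},
    {"profession": "medicine", "text_only": True,  "question": "..."},
]
print(len(select_kokushimd10(toy)))  # -> 1
```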
Testing Data
| Dataset | # of Testing Samples | Data Source | License | Data Selection Method | Version (Commit ID) |
|---|---|---|---|---|---|
| IgakuQA | 2,000 | Japanese medical licensing examinations (2018–2022) | Public (Ministry of Health, Labour and Welfare) | All data | 2bc4c3d159cf5505f6253d24a909fbd53237e239 |
Model Performance
- We evaluate our model on the IgakuQA dataset. Baseline results provided by IgakuQA are also listed below. As shown in the table, although MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 is based on an 8B-parameter model, it outperforms the ChatGPT results reported by IgakuQA.
- The fine-tuned JPharmatron-7B was used only as a baseline and is not the primary model proposed in this work; therefore, we do not upload it. Its training parameters are the same as those used for MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5.
Evaluation Results on IgakuQA Benchmark
In our experiments, the prompt used to generate answers is as follows:

```
Answer this medical multiple choice question by selecting the correct option letter (A, B, C, D, or E).
Question: {question}
Options:
{options}
Answer:
```
See the table below for evaluation results:
| Model Configuration | Overall Accuracy | Single-choice accuracy | Multiple-choice accuracy | Notes |
|---|---|---|---|---|
| Llama-3.1-Swallow-8B-Instruct-v0.5 | 55.75% | 60.33% | 30.87% | Base model |
| MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 | 62.40% | 66.31% | 41.16% | Our fine-tuned model |
| JPharmatron-7B | 61.25% | 67.02% | 29.90% | We evaluated the open-source model with our script |
| JPharmatron-7B + finetune | 65.90% | 71.11% | 37.62% | We fine-tuned the model and evaluated it with our script; accuracy improves by 4.65 percentage points over the base JPharmatron-7B |
| student_majority | 93.90% | 94.24% | 91.95% | Provided by IgakuQA; selects the option most frequently chosen by students |
| GPT-4 | 76.60% | 77.97% | 68.79% | Provided by IgakuQA benchmark |
| translate_chatgpt | 56.60% | 60.11% | 36.58% | Provided by IgakuQA benchmark; approximately ChatGPT (2023) with translation |
| ChatGPT | 53.95% | 56.99% | 36.58% | Provided by IgakuQA benchmark; approximately ChatGPT (2023) |
| GPT-3 | 40.35% | 43.13% | 24.50% | Provided by IgakuQA benchmark |
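For reference, the three accuracy columns can be computed from the letters returned by `extract_answer` / `extract_multiple_answers` in the Usage section. The exact scoring script is not published here, so the sketch below is an illustration under the assumption that a prediction counts as correct only when it exactly matches the gold answer (multiple-choice items compared as sorted, de-duplicated letter sets).

```python
# Illustrative scoring sketch (not the exact evaluation script).
def score(predictions, golds):
    """predictions/golds: lists of answer strings such as "B" or "ACE"."""
    single_hits = single_total = multi_hits = multi_total = 0
    for pred, gold in zip(predictions, golds):
        gold_norm = "".join(sorted(set(gold.upper())))
        pred_norm = "".join(sorted(set((pred or "").upper())))
        if len(gold_norm) == 1:
            single_total += 1
            single_hits += pred_norm == gold_norm
        else:
            multi_total += 1
            multi_hits += pred_norm == gold_norm
    overall = (single_hits + multi_hits) / max(single_total + multi_total, 1)
    return {
        "overall": overall,
        "single_choice": single_hits / max(single_total, 1),
        "multiple_choice": multi_hits / max(multi_total, 1),
    }

print(score(["B", "ACE", "A"], ["B", "ACE", "C"]))
# -> overall 0.667, single-choice 0.5, multiple-choice 1.0
```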
Framework versions
- PEFT 0.15.2
- Transformers 4.52.3
- PyTorch 2.5.1
- Datasets 3.6.0
- Tokenizers 0.21.2