MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5

MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 is a QLoRA fine-tune of tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5, trained to improve performance on Japanese medical exams.

  • We trained a QLoRA adapter for Japanese medical exams: MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5.
  • We chose Llama-3.1-Swallow-8B-Instruct-v0.5 as our base model because Llama-3.1 provides strong foundational capabilities and the Swallow variant is already well fine-tuned for Japanese.
  • For training, we used multiple-choice question data from the MIT-licensed portion of JMedBench (218,912 samples) and pharmacy-related data from KokushiMD-10 (1,386 samples). We evaluated the model on IgakuQA, which includes 2,000 medical exam questions from 2018–2022. Because the training set is relatively large, we trained for only one epoch and adopted the lightweight QLoRA technique.
  • After fine-tuning, the model's overall accuracy on IgakuQA improved from 55.75% to 62.40%, a gain of 6.65 percentage points. Despite being only an 8B model, it outperforms the ChatGPT baselines reported in IgakuQA (ChatGPT: 53.95%, Translate-ChatGPT: 56.60%).

Model Overview

Base Model Reference

The following table shows the models related to our work.

Training Configuration

We used Axolotl to train the QLoRA adapter.

Built with Axolotl

See axolotl training config

axolotl version: 0.10.0

base_model: tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

datasets:
  - path: merged_medical_qa_MIT.json
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/qlora-out_swallow-8b

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: "<|end_of_text|>" 
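
A few quantities implied by the config above can be worked out directly. The sketch below derives the number of packed sequences per optimizer step, the resulting upper bound on tokens per step, and the standard LoRA scaling factor `lora_alpha / lora_r`. The GPU count is not stated in this card, so `num_gpus = 1` is a placeholder assumption.

```python
# Values copied from the Axolotl config above.
micro_batch_size = 2
gradient_accumulation_steps = 4
sequence_len = 2048
lora_r = 32
lora_alpha = 16

# Assumption: the number of GPUs is not given in the model card; 1 is a placeholder.
num_gpus = 1

# Packed sequences consumed per optimizer step.
sequences_per_step = micro_batch_size * gradient_accumulation_steps * num_gpus

# Upper bound on tokens per optimizer step (each packed sequence holds at most sequence_len tokens).
max_tokens_per_step = sequences_per_step * sequence_len

# Standard LoRA scaling applied to the low-rank update: alpha / r.
lora_scaling = lora_alpha / lora_r

print(sequences_per_step, max_tokens_per_step, lora_scaling)
```

With these settings the low-rank update is scaled by 0.5, and each optimizer step sees 8 packed sequences per GPU.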

Usage

  • Preparing environment
conda create --name medexamdoc python=3.11
conda activate medexamdoc
pip install PEFT==0.15.2 Transformers==4.52.3 torch==2.5.1 Datasets==3.6.0 Tokenizers==0.21.2
hf download IngentaAITeam/MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5
  • Python inference
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re
import time

# Model repository to load
model_path = "IngentaAITeam/MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5"

def load_model(model_name, device="cuda"):
    """Load the tokenizer and model."""
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None
    )
    model.eval()
    return tokenizer, model

def create_medical_prompt_template():
    """Create the prompt template for medical questions."""
    template = """Answer this medical multiple choice question by selecting the correct option letter (A, B, C, D, or E).

Question: {question}

Options:
{options}
Answer:"""
    return template

def format_question(question_data):
    """Format a question into the prompt."""
    question = question_data['question']
    options = question_data['options']
    
    # Format the options (exactly matches the training data format)
    options_text = ""
    for key, value in options.items():
        options_text += f"{key}. {value}\n"
    
    template = create_medical_prompt_template()
    prompt = template.format(
        question=question,
        options=options_text.rstrip()  # Remove the trailing newline
    )
    
    return prompt

def extract_answer(response):
    """Extract the first option letter from the model's response."""
    pattern = r'\b[A-E]\b'
    matches = re.findall(pattern, response.upper())
    return matches[0] if matches else None

def extract_multiple_answers(response):
    """Extract all option letters for multi-answer questions, removing duplicates."""
    pattern = r'\b[A-E]\b'
    matches = re.findall(pattern, response.upper())
    # Remove duplicates (preserving order), then sort
    unique_matches = list(dict.fromkeys(matches))
    return ''.join(sorted(unique_matches)) if unique_matches else None

def generate_answer(tokenizer, model, prompt, device="cuda"):
    """Generate an answer for the given prompt."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Measure inference time
    start_time = time.time()
    with torch.no_grad():
        # Greedy decoding: with do_sample=False, temperature is not used, so it is not passed
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    end_time = time.time()
    
    # Decode only the newly generated tokens (skip the prompt)
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    inference_time = end_time - start_time
    
    return response.strip(), inference_time

# Usage example
tokenizer, model = load_model(model_path)

# Example medical question
question_data = {
    'question': 'What is the most common cause of acute myocardial infarction?',
    'options': {
        'A': 'Coronary artery spasm',
        'B': 'Atherosclerotic plaque rupture',
        'C': 'Coronary artery embolism',
        'D': 'Coronary artery dissection',
        'E': 'Takotsubo cardiomyopathy'
    }
}

# Format the prompt
prompt = format_question(question_data)
print("Prompt:")
print(prompt)
print("---" * 40)

# Generate the answer
response, inference_time = generate_answer(tokenizer, model, prompt)
predicted_answer = extract_answer(response)

print("Model response:")
print(response)
print(f"Extracted option: {predicted_answer}")
print(f"Inference time: {inference_time:.3f} s")
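
The usage example above only calls `extract_answer`; `extract_multiple_answers` is defined but not exercised. The sketch below repeats both helpers so that it runs on its own, without loading the model, and applies them to a few hypothetical response strings. These strings are invented for illustration and are not real model outputs.

```python
import re

def extract_answer(response):
    """Return the first standalone option letter (A-E), or None."""
    matches = re.findall(r'\b[A-E]\b', response.upper())
    return matches[0] if matches else None

def extract_multiple_answers(response):
    """Return all standalone option letters, deduplicated and sorted, or None."""
    matches = re.findall(r'\b[A-E]\b', response.upper())
    unique_matches = list(dict.fromkeys(matches))
    return ''.join(sorted(unique_matches)) if unique_matches else None

# Hypothetical responses, for illustration only.
single = extract_answer("B")
multi = extract_multiple_answers("C, A")
multi_sentence = extract_multiple_answers("The answer is E and B")
no_letter = extract_answer("I am not sure.")

# Caveat: the response is upper-cased first, so the English article "a"
# is also picked up as option A.
with_article = extract_multiple_answers("This is a hard one: D")

print(single, multi, multi_sentence, no_letter, with_article)
```

Because of the upper-casing caveat shown in the last case, the helpers are most reliable when the model replies with option letters only.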

Training and evaluation data

Training Data

| Dataset | # of Training Samples | Data Source | License | Data Selection Method | Version (Commit ID) |
| JMedBench | 218,912 | medmcqa_jp (translated from MedMCQA); usmleqa_jp (translated from MedQA); medqa_jp (translated from MedQA); mmlu_medical_jp (translated from MMLU); pubmedqa_jp (translated from PubMedQA) | MIT | MultipleChoiceQA | fe772d4fb76c11a4b24e06a2d06c72a7e3e32ef5 |
| KokushiMD-10 | 1,386 | Japanese national healthcare licensing examinations (2020–2024) | MIT | text_only=True & profession=pharmacy (some questions with profession=medicine overlap with the IgakuQA test set, so profession=medicine is excluded) | c381c014c6769d0a8ca40356d7c30a969a12816d |
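
The training config declares `type: alpaca` for `merged_medical_qa_MIT.json`, but this card does not show what a record looks like. As a rough sketch only, a multiple-choice item could be mapped to Alpaca-style `instruction` / `input` / `output` fields as below. The helper `to_alpaca_record`, the split of the prompt between `instruction` and `input`, and the sample item are all assumptions for illustration, not the authors' preprocessing.

```python
import json

# Instruction text taken from the prompt template used elsewhere in this card.
INSTRUCTION = (
    "Answer this medical multiple choice question by selecting "
    "the correct option letter (A, B, C, D, or E)."
)

def to_alpaca_record(question, options, answer):
    """Hypothetical helper: build one Alpaca-style record from a multiple-choice item."""
    # One "KEY. text" line per option, mirroring the inference-time formatting.
    options_text = "\n".join(f"{key}. {value}" for key, value in options.items())
    return {
        "instruction": INSTRUCTION,
        "input": f"Question: {question}\n\nOptions:\n{options_text}",
        "output": answer,
    }

# Invented sample item, for illustration only.
record = to_alpaca_record(
    "Which vitamin deficiency causes scurvy?",
    {"A": "Vitamin A", "B": "Vitamin C", "C": "Vitamin D"},
    "B",
)
print(json.dumps(record, ensure_ascii=False, indent=2))
```

Passing `ensure_ascii=False` keeps Japanese text readable when such records are written to a JSON file.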

Testing Data

| Dataset | # of Testing Samples | Data Source | License | Data Selection Method | Version (Commit ID) |
| IgakuQA | 2,000 | Japanese medical licensing examinations (2018–2022) | Public (Ministry of Health, Labour and Welfare) | All data | 2bc4c3d159cf5505f6253d24a909fbd53237e239 |

Model Performance

  • We evaluated our model on the IgakuQA dataset. Baseline results provided by IgakuQA are also listed below. Although MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 is based on an 8B-parameter model, its overall accuracy exceeds the ChatGPT results reported by IgakuQA (approximately ChatGPT as of 2023).
  • The fine-tuned JPharmatron-7B served mainly as a baseline and is not the primary model proposed in this work, so we did not upload it. Its training parameters are the same as those used for MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5.

Evaluation Results on IgakuQA Benchmark

In our experiments, the prompt used to generate answers is as follows:

"""Answer this medical multiple choice question by selecting the correct option letter (A, B, C, D, or E).

Question: {question}

Options:
{options}
Answer:"""

See the table below for evaluation results:

| Model Configuration | Overall Accuracy | Single-choice Accuracy | Multiple-choice Accuracy | Notes |
| Llama-3.1-Swallow-8B-Instruct-v0.5 | 55.75% | 60.33% | 30.87% | Base model |
| MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 | 62.40% | 66.31% | 41.16% | Our fine-tuned model |
| JPharmatron-7B | 61.25% | 67.02% | 29.90% | Open-source model, evaluated with our script |
| JPharmatron-7B + finetune | 65.90% | 71.11% | 37.62% | Fine-tuned by us and evaluated with our script; overall accuracy improves by 4.65 percentage points after fine-tuning |
| student_majority | 93.90% | 94.24% | 91.95% | Provided by IgakuQA; selects the option most frequently chosen by students |
| GPT-4 | 76.60% | 77.97% | 68.79% | Provided by IgakuQA benchmark |
| translate_chatgpt | 56.60% | 60.11% | 36.58% | Provided by IgakuQA benchmark; approximately ChatGPT (2023) with translation |
| ChatGPT | 53.95% | 56.99% | 36.58% | Provided by IgakuQA benchmark; approximately ChatGPT (2023) |
| GPT-3 | 40.35% | 43.13% | 24.50% | Provided by IgakuQA benchmark |
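
The table reports overall, single-choice, and multiple-choice accuracy separately. The evaluation script itself is not included in this card, so the following is only a minimal sketch of how such a breakdown could be computed. It assumes exact-match scoring on the set of option letters and treats a question as multiple-choice when its reference answer has more than one letter. The function name `score` and the toy data are hypothetical.

```python
def score(predictions, references):
    """Return overall / single-choice / multiple-choice accuracy.

    predictions: list of option-letter strings such as "B" or "AC" (None if nothing was extracted)
    references:  list of gold option-letter strings
    """
    # Each entry holds [number correct, number of questions].
    totals = {"overall": [0, 0], "single": [0, 0], "multiple": [0, 0]}
    for pred, ref in zip(predictions, references):
        kind = "single" if len(ref) == 1 else "multiple"
        # Exact match on the letters, ignoring their order.
        correct = pred is not None and sorted(pred) == sorted(ref)
        for key in ("overall", kind):
            totals[key][0] += int(correct)
            totals[key][1] += 1
    return {key: (c / n if n else None) for key, (c, n) in totals.items()}

# Toy data, for illustration only.
preds = ["B", "AC", "D", None, "E"]
refs = ["B", "CA", "A", "BE", "E"]
result = score(preds, refs)
print(result)
```

Under this assumption a multi-answer question counts as correct only when every letter matches, which is one plausible reason multiple-choice accuracy is lower than single-choice accuracy for every model in the table.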

Framework versions

  • PEFT 0.15.2
  • Transformers 4.52.3
  • PyTorch 2.5.1
  • Datasets 3.6.0
  • Tokenizers 0.21.2