# Tags: Text Classification, Adapters, biology
# File size: 4,454 bytes (revision 845f390)

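from datasets import load_dataset
from collections import defaultdict
import json
import os
import re
import random

# Seed Python's global RNG so the oversampling draws (random.sample below) are reproducible.
random.seed(42)

folder = "data/"
os.makedirs(folder, exist_ok=True)  # make sure the output directory exists before writing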
system_message = "You are a medical diagnosis classifier. Given a description of symptoms, provide ONLY the name of the most likely diagnosis. Do not include explanations, reasoning, or additional text."

# Load and shuffle the dataset
dataset = load_dataset("sajjadhadi/disease-diagnosis-dataset", split="train")
dataset = dataset.shuffle(seed=42)

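# Function to clean symptom text into a standardized format: drop the leading prompt
# ("Patient reported the following symptoms:" / "Symptoms include:") and the closing
# question, then normalize the comma-separated symptom list.
def clean_symptom_text(text):
    pattern = r'(?:patient reported the following symptoms:|symptoms include:?)?\s*(.*?)(?:\s*(?:may indicate|based on these symptoms|what disease may the patient have\?|what is the most likely diagnosis\?).*)'
    match = re.search(pattern, text, re.IGNORECASE)
    if match:
        symptoms = match.group(1).strip()
        # Normalize separators to ", " and drop any leading/trailing comma left over.
        # (rstrip(',') alone would miss a trailing ", " produced by the substitution.)
        symptoms = re.sub(r'\s*,\s*', ', ', symptoms).strip(', ')
        return symptoms
    # None of the recognized closing phrases was found: return the text unchanged.
    return text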

# Group samples by diagnosis
diagnosis_to_samples = defaultdict(list)
for i, sample in enumerate(dataset):
    diagnosis_to_samples[sample["diagnosis"]].append(i)

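# TODO: @Tingzhen important: select the top 50 diagnoses with at least MIN_SAMPLES samples each
TARGET_SAMPLES = 300  # examples per diagnosis after balancing
MIN_SAMPLES = 75      # minimum number of original examples for a diagnosis to be kept
NUM_DIAGNOSES = 50    # number of diagnoses to keep

# Most frequent diagnoses first, filtered by MIN_SAMPLES, truncated to NUM_DIAGNOSES
top_diagnoses = [diag for diag, indices in sorted(diagnosis_to_samples.items(),
                                                  key=lambda x: len(x[1]), reverse=True)
                 if len(indices) >= MIN_SAMPLES][:NUM_DIAGNOSES]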
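
# Sketch of an alternative formulation (hypothetical helper, not used by the
# pipeline): the same selection written with collections.Counter, keeping the
# per-class minimum and the number of classes as separate parameters. It
# assumes `labels` is any iterable of diagnosis strings, e.g. dataset["diagnosis"].
from collections import Counter


def select_top_labels(labels, min_samples, top_k):
    """Return up to `top_k` labels that occur at least `min_samples` times, most frequent first."""
    counts = Counter(labels)
    # most_common() orders labels by count, highest first.
    frequent = [label for label, n in counts.most_common() if n >= min_samples]
    return frequent[:top_k]


# Example: select_top_labels(dataset["diagnosis"], MIN_SAMPLES, 50)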

print(f"Selected diagnoses: {top_diagnoses}")

# Balance the dataset: exactly TARGET_SAMPLES examples per diagnosis
balanced_indices = []
for diag in top_diagnoses:
    indices = diagnosis_to_samples[diag]
    if len(indices) >= TARGET_SAMPLES:
        # Cap at TARGET_SAMPLES (the dataset was shuffled above, so this is a random subset)
        selected_indices = indices[:TARGET_SAMPLES]
    else:
        # Oversample to reach TARGET_SAMPLES: repeat the full set as often as it fits,
        # then draw the remainder without replacement
        selected_indices = indices * (TARGET_SAMPLES // len(indices))
        remaining = TARGET_SAMPLES % len(indices)
        selected_indices.extend(random.sample(indices, remaining))
    balanced_indices.extend(selected_indices)
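
# Sketch (hypothetical helper, not part of the original script): the
# cap-or-oversample rule from the loop above as a pure function. It takes an
# explicit random.Random instance so results are reproducible without relying
# on the global `random` state; the name and the rng argument are assumptions.
import random


def cap_or_oversample(indices, target, rng):
    """Return exactly `target` indices: truncate if there are enough, else repeat and top up."""
    if not indices:
        raise ValueError("cannot oversample an empty class")
    if len(indices) >= target:
        return list(indices[:target])
    # Repeat the full list k times, then draw the remainder without replacement,
    # so every example appears either k or k + 1 times.
    repeated = list(indices) * (target // len(indices))
    repeated.extend(rng.sample(list(indices), target % len(indices)))
    return repeated


# Example: cap_or_oversample(diagnosis_to_samples[diag], TARGET_SAMPLES, random.Random(42))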

# Create balanced dataset
balanced_dataset = dataset.select(balanced_indices)
print(f"Original dataset size: {len(dataset)}, Balanced dataset size: {len(balanced_indices)}")
print(f"Number of unique diagnoses: {len(top_diagnoses)}")

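# Create train/test/validation splits (80/10/10).
# Note: the split runs after oversampling, so copies of the same example can end up in
# both train and test/validation, which can inflate evaluation scores.
splits = balanced_dataset.train_test_split(test_size=0.2, seed=42)
# Halve the held-out 20%: its "train" part becomes the test set, its "test" part the validation set.
test_valid_splits = splits['test'].train_test_split(test_size=0.5, seed=42)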
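
# Sketch of a leakage-free alternative (hypothetical helper, not the original
# method): split each class's original indices first (10% test, 10% validation,
# mirroring the 80/10/10 split above) and cap or oversample only the training
# part, so no original row is shared between splits. Test and validation keep
# the natural class frequencies. The function name and parameters are assumptions.
import random


def split_then_oversample(class_to_indices, target_train, rng, test_frac=0.1, valid_frac=0.1):
    """Return (train, test, valid) index lists.

    Assumes each index belongs to exactly one class; then no original index
    appears in more than one of the three lists (train may repeat indices).
    """
    train, test, valid = [], [], []
    for label in sorted(class_to_indices):  # sorted for a deterministic class order
        idx = list(class_to_indices[label])
        rng.shuffle(idx)
        n_test = int(len(idx) * test_frac)
        n_valid = int(len(idx) * valid_frac)
        test.extend(idx[:n_test])
        valid.extend(idx[n_test:n_test + n_valid])
        class_train = idx[n_test + n_valid:]
        if not class_train:
            continue
        if len(class_train) >= target_train:
            train.extend(class_train[:target_train])
        else:
            # Repeat the whole training part, then top up without replacement.
            train.extend(class_train * (target_train // len(class_train)))
            train.extend(rng.sample(class_train, target_train % len(class_train)))
    return train, test, valid


# Example (240 = 80% of TARGET_SAMPLES):
# tr, te, va = split_then_oversample(
#     {d: diagnosis_to_samples[d] for d in top_diagnoses}, 240, random.Random(42))
# train_ds, test_ds, valid_ds = (dataset.select(i) for i in (tr, te, va))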

# Convert samples to chat-style records (system/user/assistant messages) and save as JSONL
def save_as_jsonl(dataset, filename):
    with open(filename, 'w') as file:
        for sample in dataset:
            cleaned_text = clean_symptom_text(sample["text"])
            conversation = {
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": cleaned_text},
                    {"role": "assistant", "content": sample["diagnosis"]}
                ]
            }
            file.write(json.dumps(conversation) + '\n')
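
# Sketch (hypothetical helper, not part of the original script): a structural
# check for one line of the JSONL files written by save_as_jsonl, assuming the
# three-message system/user/assistant layout it produces. A fine-tuning tool
# may impose further schema rules that this check does not cover.
import json


def is_valid_chat_line(line):
    """Return True if `line` holds system, user, assistant messages with non-empty string content."""
    try:
        record = json.loads(line)
    except json.JSONDecodeError:
        return False
    messages = record.get("messages") if isinstance(record, dict) else None
    if not isinstance(messages, list) or len(messages) != 3:
        return False
    roles = tuple(m.get("role") for m in messages if isinstance(m, dict))
    if roles != ("system", "user", "assistant"):
        return False
    return all(isinstance(m.get("content"), str) and m["content"].strip() for m in messages)


# Example: all(is_valid_chat_line(l) for l in open(folder + "train.jsonl", encoding="utf-8"))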

# Save datasets
save_as_jsonl(splits["train"], folder + "train.jsonl")
save_as_jsonl(test_valid_splits["train"], folder + "test.jsonl")
save_as_jsonl(test_valid_splits["test"], folder + "valid.jsonl")

# Print statistics
print("Dataset splits:")
print(f"  Train: {len(splits['train'])}")
print(f"  Test: {len(test_valid_splits['train'])}")
print(f"  Validation: {len(test_valid_splits['test'])}")

# Sanity check: print the first three training examples
print("\nSample training examples:")
with open(folder + "train.jsonl", 'r') as file:
    for i, line in enumerate(file):
        if i >= 3:
            break
        example = json.loads(line)
        print(f"Example {i+1}:")
        print(f"  System: {example['messages'][0]['content']}")
        print(f"  User: {example['messages'][1]['content']}")
        print(f"  Assistant: {example['messages'][2]['content']}")
        print()

# Check class distribution in training set
class_counts = defaultdict(int)
with open(folder + "train.jsonl", 'r') as file:
    for line in file:
        example = json.loads(line)
        diagnosis = example['messages'][2]['content']
        class_counts[diagnosis] += 1

print("\nClass distribution in training set (10 most frequent):")
for diagnosis, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f"  {diagnosis}: {count}")
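
# Sketch (hypothetical helper, not part of the original script): the check above
# reads only train.jsonl and prints the ten largest classes. Since
# train_test_split is called without stratification, per-class counts in the
# test and validation files can vary, so it may help to count all three files.
import json
from collections import Counter


def label_counts(lines):
    """Count assistant labels (the diagnosis) in an iterable of chat-format JSONL lines."""
    counts = Counter()
    for line in lines:
        if line.strip():  # skip blank lines
            counts[json.loads(line)["messages"][-1]["content"]] += 1
    return counts


# Example (file names as written above):
# for name in ("train", "test", "valid"):
#     with open(folder + name + ".jsonl", encoding="utf-8") as f:
#         c = label_counts(f)
#     print(f"{name}: {len(c)} classes, min {min(c.values())}, max {max(c.values())}")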