Text Classification · Adapters · biology
naifenn committed on
Commit 845f390 · verified · 1 Parent(s): bb7f618

Upload load_data.py with huggingface_hub

Files changed (1)
  1. load_data.py +109 -0
load_data.py ADDED
from datasets import load_dataset
from collections import defaultdict
import json
import os
import random
import re

random.seed(42)  # make the oversampling step reproducible

folder = "data/"
os.makedirs(folder, exist_ok=True)  # create the output directory if it is missing
system_message = "You are a medical diagnosis classifier. Given a description of symptoms, provide ONLY the name of the most likely diagnosis. Do not include explanations, reasoning, or additional text."

# Load and shuffle the dataset
dataset = load_dataset("sajjadhadi/disease-diagnosis-dataset", split="train")
dataset = dataset.shuffle(seed=42)
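# Each record provides a free-text "text" field and a "diagnosis" label;
# these are the only two fields the script relies on below.
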
# Function to clean symptom text into a standardized format
def clean_symptom_text(text):
    pattern = r'(?:patient reported the following symptoms:|symptoms include:?)?\s*(.*?)(?:\s*(?:may indicate|based on these symptoms|what disease may the patient have\?|what is the most likely diagnosis\?).*)'
    match = re.search(pattern, text, re.IGNORECASE)
    if match:
        symptoms = match.group(1).strip()
        # Normalize comma spacing and drop any trailing ", "
        symptoms = re.sub(r'\s*,\s*', ', ', symptoms).rstrip(', ')
        return symptoms
    return text  # fall back to the raw text when no boilerplate is found
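
# Illustrative behaviour of the cleaner (hypothetical input, not a verbatim
# dataset record):
#   clean_symptom_text("Patient reported the following symptoms: fever ,cough, what disease may the patient have?")
#   -> "fever, cough"
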
# Group samples by diagnosis
diagnosis_to_samples = defaultdict(list)
for i, sample in enumerate(dataset):
    diagnosis_to_samples[sample["diagnosis"]].append(i)

# Select the top TOP_N diagnoses with at least MIN_SAMPLES examples
# TODO(@Tingzhen): confirm these thresholds
TARGET_SAMPLES = 300  # target examples per diagnosis after balancing
MIN_SAMPLES = 75      # minimum raw examples required to keep a diagnosis
TOP_N = 50            # number of diagnosis classes to keep

top_diagnoses = [diag for diag, indices in sorted(diagnosis_to_samples.items(),
                                                  key=lambda x: len(x[1]), reverse=True)
                 if len(indices) >= MIN_SAMPLES][:TOP_N]

print(top_diagnoses)
# Balance the dataset: ensure TARGET_SAMPLES per diagnosis
balanced_indices = []
for diag in top_diagnoses:
    indices = diagnosis_to_samples[diag]
    if len(indices) >= TARGET_SAMPLES:
        # Cap at TARGET_SAMPLES
        selected_indices = indices[:TARGET_SAMPLES]
    else:
        # Oversample: repeat the full set, then top up with a random remainder
        selected_indices = indices * (TARGET_SAMPLES // len(indices))
        remaining = TARGET_SAMPLES % len(indices)
        selected_indices.extend(random.sample(indices, remaining))
    balanced_indices.extend(selected_indices)

# Create balanced dataset
balanced_dataset = dataset.select(balanced_indices)
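
# Sanity check (an addition, not in the original upload): every kept diagnosis
# should now contribute exactly TARGET_SAMPLES indices.
assert len(balanced_indices) == TARGET_SAMPLES * len(top_diagnoses)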
print(f"Original dataset size: {len(dataset)}, Balanced dataset size: {len(balanced_indices)}")
print(f"Number of unique diagnoses: {len(top_diagnoses)}")

# Create train/test/validation splits: 80% train, then split the remaining
# 20% in half to get 10% test and 10% validation
splits = balanced_dataset.train_test_split(test_size=0.2, seed=42)
test_valid_splits = splits['test'].train_test_split(test_size=0.5, seed=42)

# Convert samples to the chat-messages format and save as JSONL
def save_as_jsonl(split, filename):  # `split` rather than `dataset`, to avoid shadowing the global
    with open(filename, 'w') as file:
        for sample in split:
            cleaned_text = clean_symptom_text(sample["text"])
            conversation = {
                "messages": [
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": cleaned_text},
                    {"role": "assistant", "content": sample["diagnosis"]}
                ]
            }
            file.write(json.dumps(conversation) + '\n')

# Save datasets
save_as_jsonl(splits["train"], folder + "train.jsonl")
save_as_jsonl(test_valid_splits["train"], folder + "test.jsonl")
save_as_jsonl(test_valid_splits["test"], folder + "valid.jsonl")
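# Each emitted line is one JSON object of this shape (the values below are
# hypothetical, and wrapped here only for readability):
# {"messages": [{"role": "system", "content": "You are a medical diagnosis classifier. ..."},
#               {"role": "user", "content": "fever, cough, shortness of breath"},
#               {"role": "assistant", "content": "Pneumonia"}]}
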
# Print statistics
print("Dataset splits:")
print(f"  Train: {len(splits['train'])}")
print(f"  Test: {len(test_valid_splits['train'])}")
print(f"  Validation: {len(test_valid_splits['test'])}")

# Sample validation
print("\nSample validation:")
with open(folder + "train.jsonl", 'r') as file:
    for i, line in enumerate(file):
        if i >= 3:
            break
        example = json.loads(line)
        print(f"Example {i+1}:")
        print(f"  System: {example['messages'][0]['content']}")
        print(f"  User: {example['messages'][1]['content']}")
        print(f"  Assistant: {example['messages'][2]['content']}")
        print()

# Check class distribution in training set
class_counts = defaultdict(int)
with open(folder + "train.jsonl", 'r') as file:
    for line in file:
        example = json.loads(line)
        diagnosis = example['messages'][2]['content']
        class_counts[diagnosis] += 1

print("\nClass distribution in training set:")
for diagnosis, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f"  {diagnosis}: {count}")