"""prep.py -- preprocessing for link token classification.

Pipeline: (1) read scraped markdown from SQLite, (2) clean it and wrap link
anchor text in LINK markers (one CSV line per document), (3) tokenize, align
binary token labels to the link spans, split by document and slice into
overlapping windows, (4) write train/val JSONL plus a summary file.
"""

import os
import re
import csv
import sys
import json
import math
import sqlite3
import random
import argparse
from typing import List, Tuple, Dict
from concurrent.futures import ProcessPoolExecutor

import numpy as np
from tqdm import tqdm

try:
    from transformers import AutoTokenizer
except ImportError:
    print("ERROR: transformers not installed. pip install transformers", file=sys.stderr); sys.exit(1)

LINK_START = "[LINK_START]"
LINK_END = "[LINK_END]"

# Regexes for stripping markdown markup while keeping link anchor text.
IMG_INLINE_RE = re.compile(r'!\[[^\]]*\]\([^)]*\)')
INLINE_LINK_RE = re.compile(r'\[([^\]]+)\]\([^)]*\)')
REF_LINK_RE = re.compile(r'\[([^\]]+)\]\[[^\]]+\]')
REF_DEF_RE = re.compile(r'^[ \t]{0,3}\[[^\]]+\]:\s+\S+.*$', re.MULTILINE)
AUTOLINK_RE = re.compile(r'<https?[^>]+>')
BARE_URL_RE = re.compile(r'https?://\S+|www\.\S+')
CODE_TICKS_RE = re.compile(r'`+')
EMPH_RE = re.compile(r'[*]+')
HEAD_RE = re.compile(r'^[ \t]*#+[ \t]*', re.MULTILINE)
QUOTE_RE = re.compile(r'^(>+\s*)+', re.MULTILINE)
WS_RE = re.compile(r'\s+')


def annotate_one(md_text: str) -> Tuple[str, int]:
    """Return (single-line annotated text, has_marker[0/1])."""
    if not md_text:
        return "", 0
    t = md_text

    # Drop images and URLs; keep link anchor text wrapped in LINK markers.
    t = IMG_INLINE_RE.sub('', t)
    t = INLINE_LINK_RE.sub(lambda m: f"{LINK_START}{m.group(1)}{LINK_END}", t)
    t = REF_LINK_RE.sub(lambda m: f"{LINK_START}{m.group(1)}{LINK_END}", t)
    t = REF_DEF_RE.sub('', t)
    t = AUTOLINK_RE.sub('', t)
    t = BARE_URL_RE.sub('', t)

    # Strip remaining markdown markup.
    t = CODE_TICKS_RE.sub('', t)
    t = EMPH_RE.sub('', t)
    t = HEAD_RE.sub('', t)
    t = QUOTE_RE.sub('', t)

    # Collapse all whitespace so every document fits on one CSV line.
    t = WS_RE.sub(' ', t).strip()

    has = 1 if (LINK_START in t and LINK_END in t) else 0
    return t, has
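
# Illustrative sketch (not part of the original pipeline): for a markdown
# snippet like
#   "# Title\nSee [the docs](https://example.com) and ![img](a.png)"
# annotate_one returns
#   ("Title See [LINK_START]the docs[LINK_END] and", 1)
# i.e. the image and the URL are gone, the link's anchor text is kept between
# the LINK markers, and the heading marker and newlines collapse to spaces.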


def strip_and_get_spans(s: str) -> Tuple[str, List[Tuple[int, int]]]:
    """Remove LINK markers and return (plain_text, spans) in char offsets."""
    spans: List[Tuple[int, int]] = []
    out: List[str] = []
    i = 0
    n = len(s)
    in_link = False
    start_pos = -1
    while i < n:
        if s.startswith(LINK_START, i):
            if not in_link:
                in_link = True
                start_pos = len(out)
            i += len(LINK_START); continue
        if s.startswith(LINK_END, i):
            if in_link:
                in_link = False
                end_pos = len(out)
                if end_pos > start_pos >= 0:
                    spans.append((start_pos, end_pos))
                start_pos = -1
            i += len(LINK_END); continue
        out.append(s[i]); i += 1
    return "".join(out), spans


def labels_from_spans(offset_mapping: List[Tuple[int, int]], spans: List[Tuple[int, int]]) -> List[int]:
    """Binary label 1 if token overlaps any span by >=1 char, else 0."""
    labels: List[int] = []
    spans = sorted(spans)
    for ts, te in offset_mapping:
        if ts == te:
            # Zero-width offsets map to no source characters, so never a link.
            labels.append(0); continue
        lab = 0
        for ss, se in spans:
            if te <= ss: break
            if ts >= se: continue
            lab = 1; break
        labels.append(lab)
    return labels
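
# Illustrative sketch with hypothetical offsets (real ones depend on the
# tokenizer): for spans=[(4, 12)] and
#   offset_mapping=[(0, 3), (4, 7), (8, 12), (13, 16)]
# labels_from_spans returns [0, 1, 1, 0] -- only the tokens that overlap the
# character span are labeled positive.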


def windowize_ids_and_labels(
    input_ids_no_special: List[int],
    labels_no_special: List[int],
    tokenizer: AutoTokenizer,
    max_length: int,
    doc_stride: int
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    """Slice long sequences to windows with specials (<= max_length)."""
    assert len(input_ids_no_special) == len(labels_no_special)
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials
    if cap <= 0:
        raise ValueError(f"max_length too small; specials={specials}")

    def pack(ids_no_sp: List[int], labs_no_sp: List[int]):
        ids_with = tokenizer.build_inputs_with_special_tokens(ids_no_sp)
        attn = [1] * len(ids_with)
        if specials == 2:
            # Common case: one leading and one trailing special token.
            labs_with = [0] + labs_no_sp + [0]
        else:
            # Fallback: assume one leading special; any remaining specials trail.
            pad_n = len(ids_with) - len(labs_no_sp)
            if pad_n >= 1:
                labs_with = [0] + labs_no_sp + [0] * (pad_n - 1)
            else:
                labs_with = labs_no_sp[:len(ids_with)]
        return ids_with[:max_length], attn[:max_length], labs_with[:max_length]

    if len(input_ids_no_special) <= cap:
        ids_w, attn_w, labs_w = pack(input_ids_no_special, labels_no_special)
        return [ids_w], [attn_w], [labs_w]

    step = max(cap - doc_stride, 1)
    out_ids: List[List[int]] = []
    out_attn: List[List[int]] = []
    out_labs: List[List[int]] = []
    start = 0
    total = len(input_ids_no_special)
    while start < total:
        end = min(start + cap, total)
        ids_slice = input_ids_no_special[start:end]
        labs_slice = labels_no_special[start:end]
        ids_w, attn_w, labs_w = pack(ids_slice, labs_slice)
        out_ids.append(ids_w); out_attn.append(attn_w); out_labs.append(labs_w)
        if end == total: break
        start += step
    return out_ids, out_attn, out_labs
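
# Illustrative sketch of the window arithmetic (example values): with
# max_length=512 and 2 special tokens, each window holds cap=510 content
# tokens; with doc_stride=128 the step is 510 - 128 = 382, so a 1000-token
# document yields windows starting at tokens 0, 382 and 764, each re-packed
# with special tokens and overlapping its neighbour by 128 content tokens.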


def read_markdown_from_db(db_path: str) -> List[str]:
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT full_markdown_content
            FROM scraped_data
            WHERE status_code = 200
              AND full_markdown_content IS NOT NULL
              AND TRIM(full_markdown_content) != ''
        """)
        rows = cur.fetchall()
        return [r[0] if isinstance(r[0], str) else str(r[0]) for r in rows]
    finally:
        conn.close()


def main():
    p = argparse.ArgumentParser(description="Fast end-to-end preprocessing for link token classification.")
    p.add_argument("--db", default="scraped.db", help="SQLite DB path (table scraped_data).")
    p.add_argument("--output_csv", default="train_clean.csv", help="Output cleaned CSV (quoted, one line/doc).")
    p.add_argument("--tokenizer", default="microsoft/mdeberta-v3-base", help="HF tokenizer.")
    p.add_argument("--max_length", type=int, default=512, help="Max tokens incl. specials.")
    p.add_argument("--doc_stride", type=int, default=128, help="Overlap on content tokens.")
    p.add_argument("--val_ratio", type=float, default=0.1, help="Validation ratio by document.")
    p.add_argument("--seed", type=int, default=42, help="Random seed for split.")
    p.add_argument("--batch_size", type=int, default=64, help="Tokenization batch size.")
    p.add_argument("--workers", default="auto", help="Annotation worker count: int or 'auto'.")
    args = p.parse_args()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(script_dir, args.db)
    out_csv = os.path.join(script_dir, args.output_csv)
    if not os.path.isfile(db_path):
        print(f"ERROR: DB not found: {db_path}", file=sys.stderr); sys.exit(1)

    print(f"[1/4] Read from DB: {args.db}")
    md_rows = read_markdown_from_db(db_path)
    n_docs = len(md_rows)
    print(f" Rows: {n_docs}")

    print(f"[2/4] Clean + annotate -> {args.output_csv}")
    workers = os.cpu_count() if args.workers == "auto" else int(args.workers)
    markers = 0
    written = 0

    with open(out_csv, "w", encoding="utf-8", newline="") as f_out:
        writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
        with ProcessPoolExecutor(max_workers=workers) as ex:
            for txt, has in tqdm(ex.map(annotate_one, md_rows, chunksize=512), total=n_docs, unit="doc", desc="Annotating"):
                if not txt:
                    continue
                if '\n' in txt or '\r' in txt:
                    # Defensive: keep every document on a single CSV line.
                    txt = WS_RE.sub(' ', txt).strip()
                writer.writerow([txt])
                written += 1
                markers += has

    if written == 0:
        print("ERROR: No documents written after cleaning.", file=sys.stderr); sys.exit(1)
    print(f" Written: {written} | With LINK markers: {markers}")

    print(f"[3/4] Tokenize + align + split + windowize (tokenizer={args.tokenizer})")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True)
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = args.max_length - specials
    if cap <= 0:
        print(f"ERROR: max_length too small for specials={specials}", file=sys.stderr); sys.exit(1)

    # Cleaned documents are single very long lines; raise the csv module's
    # default field size limit (128 KiB) so they can be read back.
    csv.field_size_limit(min(sys.maxsize, 2**31 - 1))
    texts = []
    with open(out_csv, "r", encoding="utf-8", newline="") as f_in:
        rdr = csv.reader(f_in, quoting=csv.QUOTE_ALL)
        for row in rdr:
            texts.append(row[0])
    num_docs = len(texts)

    plain_texts: List[str] = []
    spans_all: List[List[Tuple[int, int]]] = []
    for t in tqdm(texts, total=num_docs, unit="doc", desc="Extract spans"):
        plain, spans = strip_and_get_spans(t)
        plain_texts.append(plain)
        spans_all.append(spans)

    # Specials are added later, per window, in windowize_ids_and_labels;
    # truncation stays off so offsets cover the whole document.
    input_ids_no_sp: List[List[int]] = []
    offsets_all: List[List[Tuple[int, int]]] = []
    for i in tqdm(range(0, num_docs, args.batch_size), unit="batch", desc="Tokenize"):
        batch = plain_texts[i:i+args.batch_size]
        enc = tokenizer(
            batch,
            add_special_tokens=False,
            return_offsets_mapping=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            truncation=False,
        )
        input_ids_no_sp.extend(enc["input_ids"])
        offsets_all.extend([[(int(a), int(b)) for (a, b) in off] for off in enc["offset_mapping"]])

    labels_no_sp: List[List[int]] = []
    total_tokens = 0
    pos_tokens = 0
    for offs, spans in tqdm(zip(offsets_all, spans_all), total=num_docs, unit="doc", desc="Align labels"):
        labs = labels_from_spans(offs, spans)
        labels_no_sp.append(labs)
        total_tokens += len(labs)
        if labs:
            pos_tokens += int(np.sum(labs))

    # Split by document (not by window) so windows of one page never end up
    # in both train and validation.
    idx = list(range(num_docs))
    random.Random(args.seed).shuffle(idx)
    val_n = max(1, int(round(num_docs * args.val_ratio)))
    val_set = set(idx[:val_n])

    train_out_path = os.path.join(script_dir, "train_windows.jsonl")
    val_out_path = os.path.join(script_dir, "val_windows.jsonl")
    train_out = open(train_out_path, "w", encoding="utf-8")
    val_out = open(val_out_path, "w", encoding="utf-8")

    train_windows = 0
    val_windows = 0
    train_win_with_link = 0
    val_win_with_link = 0
    exceeding_docs = 0

    for doc_id in tqdm(range(num_docs), unit="doc", desc="Windowize+write"):
        ids = input_ids_no_sp[doc_id]
        labs = labels_no_sp[doc_id]
        if len(ids) + specials > args.max_length:
            exceeding_docs += 1
        ids_ws, attn_ws, labs_ws = windowize_ids_and_labels(ids, labs, tokenizer, args.max_length, args.doc_stride)
        target = val_out if doc_id in val_set else train_out
        for w_id, (iw, aw, lw) in enumerate(zip(ids_ws, attn_ws, labs_ws)):
            if any(x == 1 for x in lw):
                if doc_id in val_set: val_win_with_link += 1
                else: train_win_with_link += 1
            rec = {"doc_id": int(doc_id), "window_id": int(w_id), "input_ids": iw, "attention_mask": aw, "labels": lw}
            target.write(json.dumps(rec, ensure_ascii=False) + "\n")
        if doc_id in val_set: val_windows += len(ids_ws)
        else: train_windows += len(ids_ws)

    train_out.close(); val_out.close()
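
    # Illustrative sketch of one output line (token ids and labels are made up):
    #   {"doc_id": 0, "window_id": 0, "input_ids": [1, 2334, 987, 2],
    #    "attention_mask": [1, 1, 1, 1], "labels": [0, 1, 1, 0]}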

    pos_rate = (pos_tokens / total_tokens) if total_tokens else 0.0
    summary_lines = [
        "=== prep.py Summary ===",
        f"DB: {args.db}",
        f"Output CSV: {args.output_csv}",
        f"Tokenizer: {args.tokenizer}",
        f"max_length: {args.max_length} (specials={specials}, content_capacity={cap})",
        f"doc_stride: {args.doc_stride}",
        f"Documents cleaned: {num_docs}",
        f"Documents exceeding max_length (incl. specials): {exceeding_docs}",
        f"Tokens total (no specials): {total_tokens}",
        f"Positive tokens: {pos_tokens} ({pos_rate:.4%})",
        f"Train windows: {train_windows} (with_link={train_win_with_link})",
        f"Val windows: {val_windows} (with_link={val_win_with_link})",
        "Train JSONL: train_windows.jsonl",
        "Val JSONL: val_windows.jsonl",
    ]
    with open(os.path.join(script_dir, "prep_summary.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(summary_lines) + "\n")

    print("[4/4] Summary -> prep_summary.txt\n" + "\n".join(summary_lines))
    print("Done.")


if __name__ == "__main__":
    main()
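
# Illustrative sketch (assumption, not part of this script): the JSONL windows
# can be loaded for training with Hugging Face datasets, e.g.
#   from datasets import load_dataset
#   ds = load_dataset("json", data_files={"train": "train_windows.jsonl",
#                                         "validation": "val_windows.jsonl"})
# Each record carries unpadded input_ids/attention_mask/labels, so a data
# collator only needs to pad to the longest sequence in the batch.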