|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import re
|
| import csv
|
| import sys
|
| import json
|
| import math
|
| import sqlite3
|
| import random
|
| import argparse
|
| from typing import List, Tuple, Dict
|
| from concurrent.futures import ProcessPoolExecutor
|
|
|
| import numpy as np
|
| from tqdm import tqdm
|
|
|
| try:
|
| from transformers import AutoTokenizer
|
| except ImportError:
|
| print("ERROR: transformers not installed. pip install transformers", file=sys.stderr); sys.exit(1)
|
|
|
# Sentinel markers wrapped around the visible text of every Markdown link.
LINK_START = "[LINK_START]"
LINK_END = "[LINK_END]"

# Inline images: ![alt](src) -- removed entirely.
IMG_INLINE_RE = re.compile(r'!\[[^\]]*\]\([^)]*\)')
# Inline links: [text](url). Kept for backward compatibility; annotate_one
# itself uses the combined _ANY_LINK_RE below.
INLINE_LINK_RE = re.compile(r'\[([^\]]+)\]\([^)]*\)')
# Reference links: [text][ref]. Kept for backward compatibility.
REF_LINK_RE = re.compile(r'\[([^\]]+)\]\[[^\]]+\]')
# Reference definitions on their own line: "[ref]: target ..."
REF_DEF_RE = re.compile(r'^[ \t]{0,3}\[[^\]]+\]:\s+\S+.*$', re.MULTILINE)
# Autolinks such as <https://example.com>
AUTOLINK_RE = re.compile(r'<https?[^>]+>')
# Bare URLs left in running text.
BARE_URL_RE = re.compile(r'https?://\S+|www\.\S+')
# Runs of backticks (inline code / fences).
CODE_TICKS_RE = re.compile(r'`+')
# Runs of asterisks (emphasis / bullets).
EMPH_RE = re.compile(r'[*]+')
# Leading "#" heading markers.
HEAD_RE = re.compile(r'^[ \t]*#+[ \t]*', re.MULTILINE)
# Leading ">" block-quote markers.
QUOTE_RE = re.compile(r'^(>+\s*)+', re.MULTILINE)
# Any whitespace run (including newlines).
WS_RE = re.compile(r'\s+')

# Inline OR reference link in a single pattern. Substituting both forms in one
# pass means the markers inserted for one link are never re-scanned, so
# "[LINK_END][LINK_START]" (produced by two adjacent links) cannot be mistaken
# for a reference link "[text][ref]".
_ANY_LINK_RE = re.compile(r'\[([^\]]+)\](?:\([^)]*\)|\[[^\]]+\])')


def _mark_link(match):
    """Replace a matched link with its visible text wrapped in LINK markers."""
    return f"{LINK_START}{match.group(1)}{LINK_END}"


def annotate_one(md_text: str) -> Tuple[str, int]:
    """Clean one Markdown document and mark its link texts.

    Images, reference definitions, autolinks and bare URLs are removed; the
    visible text of inline and reference links is wrapped in
    LINK_START/LINK_END; backticks, asterisks, heading and block-quote
    markers are stripped; finally all whitespace is collapsed so the result
    is a single line.

    Args:
        md_text: Raw Markdown text (may be empty or None).

    Returns:
        (annotated_text, has_marker) where has_marker is 1 if the annotated
        text contains both a LINK_START and a LINK_END marker, else 0.
        Falsy input yields ("", 0).
    """
    if not md_text:
        return "", 0
    t = md_text

    # --- links and URLs ---
    t = IMG_INLINE_RE.sub('', t)
    # One pass for both link syntaxes (see _ANY_LINK_RE for why).
    t = _ANY_LINK_RE.sub(_mark_link, t)
    t = REF_DEF_RE.sub('', t)
    t = AUTOLINK_RE.sub('', t)
    t = BARE_URL_RE.sub('', t)

    # --- remaining Markdown syntax ---
    t = CODE_TICKS_RE.sub('', t)
    t = EMPH_RE.sub('', t)
    t = HEAD_RE.sub('', t)
    t = QUOTE_RE.sub('', t)

    # --- collapse to a single line ---
    t = WS_RE.sub(' ', t).strip()

    has = 1 if (LINK_START in t and LINK_END in t) else 0
    return t, has
|
|
|
|
|
def strip_and_get_spans(s: str) -> Tuple[str, List[Tuple[int, int]]]:
    """Drop LINK markers from *s* and report where the link texts end up.

    Returns (plain_text, spans). Each span is a (start, end) pair of character
    offsets into plain_text, end exclusive. A LINK_START seen while a link is
    already open is ignored, as is a LINK_END with no open link; links whose
    text is empty and links never closed produce no span.
    """
    marker_re = re.compile(f"({re.escape(LINK_START)}|{re.escape(LINK_END)})")

    pieces: List[str] = []
    found: List[Tuple[int, int]] = []
    cursor = 0        # length of the plain text emitted so far
    opened_at = None  # plain-text offset of the currently open link, if any

    # With a capturing group, split() yields text chunks interleaved with the
    # markers themselves; a text chunk can never equal a marker.
    for chunk in marker_re.split(s):
        if chunk == LINK_START:
            if opened_at is None:
                opened_at = cursor
        elif chunk == LINK_END:
            if opened_at is not None:
                if cursor > opened_at:
                    found.append((opened_at, cursor))
                opened_at = None
        else:
            pieces.append(chunk)
            cursor += len(chunk)

    return "".join(pieces), found
|
|
|
def labels_from_spans(offset_mapping: List[Tuple[int, int]], spans: List[Tuple[int, int]]) -> List[int]:
    """Label each token 1 if it overlaps any span by >=1 char, else 0.

    Args:
        offset_mapping: Per-token (start, end) character offsets, end
            exclusive. Tokens with start == end always get label 0.
        spans: (start, end) character spans, end exclusive, in any order;
            they may overlap each other. Zero-length spans cover no
            character and are ignored.

    Returns:
        One 0/1 label per entry of offset_mapping, in the same order.
    """
    from bisect import bisect_left

    # Merge the spans into disjoint, sorted intervals so that each token needs
    # a single binary search instead of a scan over every span.
    merged: List[List[int]] = []
    for ss, se in sorted(spans):
        if se <= ss:
            continue
        if merged and ss <= merged[-1][1]:
            if se > merged[-1][1]:
                merged[-1][1] = se
        else:
            merged.append([ss, se])
    starts = [m[0] for m in merged]
    ends = [m[1] for m in merged]

    labels: List[int] = []
    for ts, te in offset_mapping:
        if ts == te:
            labels.append(0)
            continue
        # Last interval that starts before the token ends. Because intervals
        # are disjoint and sorted, it is the only one that can reach the token.
        k = bisect_left(starts, te) - 1
        labels.append(1 if k >= 0 and ends[k] > ts else 0)
    return labels
|
|
|
def windowize_ids_and_labels(
    input_ids_no_special: List[int],
    labels_no_special: List[int],
    tokenizer: "AutoTokenizer",
    max_length: int,
    doc_stride: int
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    """Slice a token sequence into windows of at most max_length tokens.

    Each window is a slice of the content tokens with the tokenizer's special
    tokens added via build_inputs_with_special_tokens. Labels are padded with
    0 at the special-token positions; the attention mask is all ones.

    Args:
        input_ids_no_special: Content token ids, without special tokens.
        labels_no_special: One label per content token.
        tokenizer: Object providing num_special_tokens_to_add(pair=False) and
            build_inputs_with_special_tokens(ids).
        max_length: Maximum window length including special tokens.
        doc_stride: Number of content tokens shared by consecutive windows.

    Returns:
        (input_ids, attention_masks, labels), each a list with one entry per
        window. A sequence that fits yields exactly one window.

    Raises:
        ValueError: If max_length leaves no room for content tokens.
    """
    assert len(input_ids_no_special) == len(labels_no_special)
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials
    if cap <= 0:
        raise ValueError(f"max_length too small; specials={specials}")

    def pack(ids_no_sp: List[int], labs_no_sp: List[int]):
        """Add special tokens to one slice and align its labels with them."""
        ids_with = tokenizer.build_inputs_with_special_tokens(ids_no_sp)
        attn = [1] * len(ids_with)
        content = list(ids_no_sp)
        n_content = len(content)
        pad_n = len(ids_with) - n_content

        # Locate the content inside the wrapped sequence rather than assuming
        # where the special tokens go: some tokenizers only append specials,
        # and assuming one leading special would shift every label.
        lead = None
        for k in range(max(pad_n, 0) + 1):
            if list(ids_with[k:k + n_content]) == content:
                lead = k
                break

        if lead is not None:
            labs_with = [0] * lead + list(labs_no_sp) + [0] * (pad_n - lead)
        elif pad_n >= 1:
            # Content not found verbatim: fall back to the conventional
            # layout of one leading special and the rest trailing.
            labs_with = [0] + list(labs_no_sp) + [0] * (pad_n - 1)
        else:
            labs_with = list(labs_no_sp[:len(ids_with)])
        return ids_with[:max_length], attn[:max_length], labs_with[:max_length]

    if len(input_ids_no_special) <= cap:
        ids_w, attn_w, labs_w = pack(input_ids_no_special, labels_no_special)
        return [ids_w], [attn_w], [labs_w]

    # Advance by at least one token so the loop terminates even when
    # doc_stride >= cap.
    step = max(cap - doc_stride, 1)
    out_ids: List[List[int]] = []
    out_attn: List[List[int]] = []
    out_labs: List[List[int]] = []
    start = 0
    total = len(input_ids_no_special)
    while start < total:
        end = min(start + cap, total)
        ids_slice = input_ids_no_special[start:end]
        labs_slice = labels_no_special[start:end]
        ids_w, attn_w, labs_w = pack(ids_slice, labs_slice)
        out_ids.append(ids_w)
        out_attn.append(attn_w)
        out_labs.append(labs_w)
        if end == total:
            # This window already reaches the end of the document.
            break
        start += step
    return out_ids, out_attn, out_labs
|
|
|
|
|
def read_markdown_from_db(db_path: str) -> List[str]:
    """Load the Markdown of every successfully scraped page.

    Selects full_markdown_content from table scraped_data for rows with
    status_code = 200 whose content is neither NULL nor blank after TRIM.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        One string per selected row. BLOB values are decoded as UTF-8
        (undecodable bytes are replaced) instead of being turned into their
        Python repr; any other non-text value is converted with str().
    """
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT full_markdown_content
            FROM scraped_data
            WHERE status_code = 200
              AND full_markdown_content IS NOT NULL
              AND TRIM(full_markdown_content) != ''
        """)
        docs: List[str] = []
        for (content,) in cur.fetchall():
            if isinstance(content, str):
                docs.append(content)
            elif isinstance(content, (bytes, bytearray)):
                # str(b"...") would yield "b'...'", polluting the text.
                docs.append(bytes(content).decode("utf-8", errors="replace"))
            else:
                docs.append(str(content))
        return docs
    finally:
        conn.close()
|
|
|
|
|
def _raise_csv_field_limit() -> None:
    """Lift the csv module's per-field size cap as high as the platform allows."""
    limit = sys.maxsize
    while True:
        try:
            csv.field_size_limit(limit)
            return
        except OverflowError:
            # The limit must fit in a C long, which is narrower than
            # sys.maxsize on some platforms.
            limit //= 10


def main():
    """Run the full preprocessing pipeline from the command line.

    Steps:
      1. Read Markdown documents from the SQLite DB.
      2. Clean and annotate them in parallel and write one quoted CSV row per
         non-empty document.
      3. Re-read the CSV, strip the LINK markers into character spans,
         tokenize, derive per-token labels, split documents into train/val,
         and write windowed examples to train_windows.jsonl and
         val_windows.jsonl.
      4. Write and print a summary (prep_summary.txt).

    All input and output paths are resolved relative to this script's
    directory. Exits with status 1 if the DB is missing, if no document
    survives cleaning, or if max_length leaves no room for content tokens.
    """
    p = argparse.ArgumentParser(description="Fast end-to-end preprocessing for link token classification.")
    p.add_argument("--db", default="scraped.db", help="SQLite DB path (table scraped_data).")
    p.add_argument("--output_csv", default="train_clean.csv", help="Output cleaned CSV (quoted, one line/doc).")
    p.add_argument("--tokenizer", default="microsoft/mdeberta-v3-base", help="HF tokenizer.")
    p.add_argument("--max_length", type=int, default=512, help="Max tokens incl specials.")
    p.add_argument("--doc_stride", type=int, default=128, help="Overlap on content tokens.")
    p.add_argument("--val_ratio", type=float, default=0.1, help="Validation ratio by document.")
    p.add_argument("--seed", type=int, default=42, help="Random seed for split.")
    p.add_argument("--batch_size", type=int, default=64, help="Tokenization batch size.")
    p.add_argument("--workers", default="auto", help="Annotation worker count: int or 'auto'.")
    args = p.parse_args()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    db_path = os.path.join(script_dir, args.db)
    out_csv = os.path.join(script_dir, args.output_csv)
    if not os.path.isfile(db_path):
        print(f"ERROR: DB not found: {db_path}", file=sys.stderr)
        sys.exit(1)

    # ---- 1) read ----
    print(f"[1/4] Read from DB: {args.db}")
    md_rows = read_markdown_from_db(db_path)
    n_docs = len(md_rows)
    print(f" Rows: {n_docs}")

    # ---- 2) clean + annotate ----
    print(f"[2/4] Clean + annotate -> {args.output_csv}")
    workers = os.cpu_count() if args.workers == "auto" else int(args.workers)
    markers = 0
    written = 0

    with open(out_csv, "w", encoding="utf-8", newline="") as f_out:
        writer = csv.writer(f_out, quoting=csv.QUOTE_ALL)
        with ProcessPoolExecutor(max_workers=workers) as ex:
            for txt, has in tqdm(ex.map(annotate_one, md_rows, chunksize=512), total=n_docs, unit="doc", desc="Annotating"):
                if not txt:
                    continue
                if '\n' in txt or '\r' in txt:
                    txt = WS_RE.sub(' ', txt).strip()
                writer.writerow([txt])
                written += 1
                markers += has

    if written == 0:
        print("ERROR: No documents written after cleaning.", file=sys.stderr)
        sys.exit(1)
    print(f" Written: {written} | With LINK markers: {markers}")

    # ---- 3) tokenize + align + split + windowize ----
    print(f"[3/4] Tokenize + align + split + windowize (tokenizer={args.tokenizer})")
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True)
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = args.max_length - specials
    if cap <= 0:
        print(f"ERROR: max_length too small for specials={specials}", file=sys.stderr)
        sys.exit(1)

    # Each document is a single CSV field; the csv module's default field
    # limit (131072 characters) would reject long pages on read-back.
    _raise_csv_field_limit()
    texts = []
    with open(out_csv, "r", encoding="utf-8", newline="") as f_in:
        rdr = csv.reader(f_in, quoting=csv.QUOTE_ALL)
        for row in rdr:
            texts.append(row[0])
    num_docs = len(texts)

    plain_texts: List[str] = []
    spans_all: List[List[Tuple[int, int]]] = []
    for t in tqdm(texts, total=num_docs, unit="doc", desc="Extract spans"):
        plain, spans = strip_and_get_spans(t)
        plain_texts.append(plain)
        spans_all.append(spans)

    input_ids_no_sp: List[List[int]] = []
    offsets_all: List[List[Tuple[int, int]]] = []
    for i in tqdm(range(0, num_docs, args.batch_size), unit="batch", desc="Tokenize"):
        batch = plain_texts[i:i+args.batch_size]
        enc = tokenizer(
            batch,
            add_special_tokens=False,
            return_offsets_mapping=True,
            return_attention_mask=False,
            return_token_type_ids=False,
            truncation=False,
        )
        input_ids_no_sp.extend(enc["input_ids"])
        offsets_all.extend([[(int(a), int(b)) for (a, b) in off] for off in enc["offset_mapping"]])

    labels_no_sp: List[List[int]] = []
    total_tokens = 0
    pos_tokens = 0
    for offs, spans in tqdm(zip(offsets_all, spans_all), total=num_docs, unit="doc", desc="Align labels"):
        labs = labels_from_spans(offs, spans)
        labels_no_sp.append(labs)
        total_tokens += len(labs)
        if labs:
            pos_tokens += int(np.sum(labs))

    # Split by document (not by window) so windows of one document never end
    # up on both sides of the split.
    idx = list(range(num_docs))
    random.Random(args.seed).shuffle(idx)
    val_n = max(1, int(round(num_docs * args.val_ratio)))
    val_set = set(idx[:val_n])

    train_out_path = os.path.join(script_dir, "train_windows.jsonl")
    val_out_path = os.path.join(script_dir, "val_windows.jsonl")

    train_windows = 0
    val_windows = 0
    train_win_with_link = 0
    val_win_with_link = 0
    exceeding_docs = 0

    # Context managers guarantee both files are closed even if windowizing or
    # serialisation raises part-way through.
    with open(train_out_path, "w", encoding="utf-8") as train_out, \
            open(val_out_path, "w", encoding="utf-8") as val_out:
        for doc_id in tqdm(range(num_docs), unit="doc", desc="Windowize+write"):
            ids = input_ids_no_sp[doc_id]
            labs = labels_no_sp[doc_id]
            if len(ids) + specials > args.max_length:
                exceeding_docs += 1
            ids_ws, attn_ws, labs_ws = windowize_ids_and_labels(ids, labs, tokenizer, args.max_length, args.doc_stride)
            in_val = doc_id in val_set
            target = val_out if in_val else train_out
            for w_id, (iw, aw, lw) in enumerate(zip(ids_ws, attn_ws, labs_ws)):
                if any(x == 1 for x in lw):
                    if in_val:
                        val_win_with_link += 1
                    else:
                        train_win_with_link += 1
                rec = {"doc_id": int(doc_id), "window_id": int(w_id), "input_ids": iw, "attention_mask": aw, "labels": lw}
                target.write(json.dumps(rec, ensure_ascii=False) + "\n")
            if in_val:
                val_windows += len(ids_ws)
            else:
                train_windows += len(ids_ws)

    # ---- 4) summary ----
    pos_rate = (pos_tokens / total_tokens) if total_tokens else 0.0
    summary_lines = [
        "=== prep.py Summary ===",
        f"DB: {args.db}",
        f"Output CSV: {args.output_csv}",
        f"Tokenizer: {args.tokenizer}",
        f"max_length: {args.max_length} (specials={specials}, content_capacity={cap})",
        f"doc_stride: {args.doc_stride}",
        f"Documents cleaned: {num_docs}",
        f"Documents exceeding max_length (incl specials): {exceeding_docs}",
        f"Tokens total (no specials): {total_tokens}",
        f"Positive tokens: {pos_tokens} ({pos_rate:.4%})",
        f"Train windows: {train_windows} (with_link={train_win_with_link})",
        f"Val windows: {val_windows} (with_link={val_win_with_link})",
        "Train JSONL: train_windows.jsonl",
        "Val JSONL: val_windows.jsonl",
    ]
    with open(os.path.join(script_dir, "prep_summary.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(summary_lines) + "\n")

    print("[4/4] Summary -> prep_summary.txt\n" + "\n".join(summary_lines))
    print("Done.")


if __name__ == "__main__":
    main()
|
|
|