Commit 91b7273e authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor(specs): restructure spec handling and database interactions

* Remove SpecCatalog and integrate its functionality into SpecDatabase.
* Introduce new utility functions for normalization in utils.
* Update SpecDownloads to utilize the new SpecDatabase.
* Refactor tests to accommodate changes in database interactions.
* Ensure all spec-related operations are streamlined through the new database structure.
parent 2ffa247f
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -49,6 +49,7 @@ dev = [
    "mdformat>=1.0.0",
    "undersort>=0.1.5",
    "specify-cli",
    "pydeps>=3.0.2",
]

[build-system]
+10 −7
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ import zipfile
from contextlib import suppress
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, cast
from typing import cast
from urllib.parse import urlparse

import requests
@@ -22,14 +22,12 @@ from tdoc_crawler.crawlers.meeting_doclist import DocumentListError, fetch_meeti
from tdoc_crawler.http_client import download_to_path
from tdoc_crawler.logging import get_logger
from tdoc_crawler.models import MeetingMetadata, TDocMetadata
from tdoc_crawler.specs.database import SpecDatabase
from tdoc_crawler.specs.downloads import SpecDownloads
from tdoc_crawler.specs.sources.base import FunctionSpecSource, SpecSource
from tdoc_crawler.specs.sources.threegpp import fetch_threegpp_metadata
from tdoc_crawler.specs.sources.whatthespec import fetch_whatthespec_metadata

if TYPE_CHECKING:
    from tdoc_crawler.database import TDocDatabase

logger = get_logger(__name__)


@@ -227,7 +225,8 @@ def clear_checkout_specs(checkout_dir: Path) -> int:

def checkout_specs(
    spec_numbers: list[str],
    database: TDocDatabase,
    checkout_dir: Path,
    database: SpecDatabase,
    release: str = "latest",
    doc_only: bool = False,
    cache_manager_name: str | None = None,
@@ -236,7 +235,8 @@ def checkout_specs(

    Args:
            spec_numbers: List of spec numbers to checkout
            database: TDocDatabase instance for metadata lookup
            checkout_dir: Base checkout directory
            database: SpecDatabase instance for metadata lookup
            release: Release version to checkout
            doc_only: If True, download only document files instead of full zip
            cache_manager_name: Optional cache manager name for HTTP caching
@@ -271,6 +271,7 @@ class CheckoutResult:

def checkout_tdocs(
    results: list[TDocMetadata],
    checkout_dir: Path,
    force: bool = False,
    session: requests.Session | None = None,
    cache_manager_name: str | None = None,
@@ -279,6 +280,7 @@ def checkout_tdocs(

    Args:
        results: List of TDocMetadata to checkout
        checkout_dir: Base checkout directory
        force: If True, re-download even if already exists
        session: Optional requests.Session to reuse for downloads
        cache_manager_name: Optional cache manager name for HTTP caching
@@ -307,6 +309,7 @@ def checkout_tdocs(

def checkout_meeting_tdocs(
    meetings: list[MeetingMetadata],
    checkout_dir: Path,
    http_cache_dir: Path,
    session: requests.Session | None = None,
    cache_manager_name: str | None = None,
@@ -341,7 +344,7 @@ def checkout_meeting_tdocs(
            if metadata.tdoc_id not in unique:
                unique[metadata.tdoc_id] = metadata

    return checkout_tdocs(list(unique.values()), force=False, session=session, cache_manager_name=cache_manager_name)
    return checkout_tdocs(list(unique.values()), checkout_dir, force=False, session=session, cache_manager_name=cache_manager_name)


__all__ = [
+2 −0
Original line number Diff line number Diff line
@@ -20,6 +20,8 @@ TDOC_VIEW_URL: Final[str] = f"{PORTAL_BASE_URL}/ngppapp/CreateTdoc.Aspx"
TDOC_DOWNLOAD_URL: Final[str] = f"{PORTAL_BASE_URL}/ngppapp/DownloadTDoc.aspx"
LOGIN_URL: Final[str] = f"{PORTAL_BASE_URL}/login.aspx"

SPEC_URL_TEMPLATE: Final[str] = "https://www.3gpp.org/ftp/Specs/archive/{series}/{normalized}/{file_name}"

DATE_PATTERN: Final[re.Pattern[str]] = re.compile(r"(\d{4}[\-\u2010-\u2015]\d{2}[\-\u2010-\u2015]\d{2})")

MEETING_CODE_REGISTRY: Final[dict[str, list[tuple[str, str | None]]]] = {
+64 −2
Original line number Diff line number Diff line
@@ -4,14 +4,76 @@ from __future__ import annotations

import logging

from tdoc_crawler.database.connection import SpecDatabase, TDocDatabase
from tdoc_crawler.database.connection import TDocDatabase
from tdoc_crawler.database.errors import DatabaseError
from tdoc_crawler.utils.normalization import normalize_portal_meeting_name

logger = logging.getLogger(__name__)


__all__ = [
    "DatabaseError",
    "SpecDatabase",
    "TDocDatabase",
    "resolve_meeting_id",
]


def resolve_meeting_id(database: TDocDatabase, meeting_name: str) -> int | None:
    """Resolve meeting name to meeting_id from database.

    Uses fuzzy matching to handle variations in meeting names:
    - Exact match (case-insensitive)
    - Normalized name match
    - Prefix/suffix matching for variations like "SA4-e" vs "3GPPSA4-e"

    Args:
        database: Database connection
        meeting_name: Meeting identifier (e.g., "SA4#133-e" or "S4-133-e")

    Returns:
        Meeting ID if found, None otherwise
    """
    # Imported locally rather than at module level (kept as in surrounding code).
    from tdoc_crawler.models import MeetingQueryConfig, SortOrder

    query = MeetingQueryConfig(
        cache_dir=database.db_file.parent,
        working_groups=None,
        subgroups=None,
        limit=None,
        order=SortOrder.DESC,
        include_without_files=True,
    )
    meetings = database.query_meetings(query)

    def _is_fuzzy_match(wanted: str, stored: str | None) -> bool:
        """True when wanted equals stored (case-insensitive) or either is a prefix/suffix of the other."""
        if not stored:
            return False
        a = wanted.lower()
        b = stored.lower()
        if a == b:
            return True
        return b.startswith(a) or b.endswith(a) or a.startswith(b) or a.endswith(b)

    # Try the name exactly as given first, then its normalized form (if different).
    names_to_try = [meeting_name]
    normalized_name = normalize_portal_meeting_name(meeting_name)
    if normalized_name != meeting_name:
        names_to_try.append(normalized_name)

    for wanted in names_to_try:
        for meeting in meetings:
            if _is_fuzzy_match(wanted, meeting.short_name):
                return meeting.meeting_id

    return None
+31 −209
Original line number Diff line number Diff line
"""Database access layer backed by pydantic_sqlite."""

import contextlib
import json
import logging
from collections import defaultdict
@@ -23,16 +22,16 @@ from tdoc_crawler.models import (
    SortOrder,
    TDocMetadata,
    WorkingGroup,
    utc_now,
)
from tdoc_crawler.models.specs import Specification, SpecificationDownload, SpecificationSourceRecord, SpecificationVersion, SpecQueryFilters, SpecQueryResult
from tdoc_crawler.models.specs import Specification, SpecificationDownload, SpecificationSourceRecord, SpecificationVersion
from tdoc_crawler.models.subworking_groups import SUBWORKING_GROUP_RECORDS
from tdoc_crawler.specs.normalization import normalize_portal_meeting_name
from tdoc_crawler.utils.misc import utc_now
from tdoc_crawler.utils.normalization import normalize_portal_meeting_name

_logger = logging.getLogger(__name__)


class _DocDatabase:
class DocDatabase:
    """High-level facade for all database operations."""

    # Map table names to their model classes and safe names
@@ -84,6 +83,32 @@ class _DocDatabase:
        """
        return self._clear_tables(self.model_map.keys())

    def clear_tdocs(self) -> int:
        """Delete every TDoc row from the database.

        Returns:
            Number of TDocs deleted
        """
        # _clear_tables returns a {table: deleted_count} mapping; absent key -> 0.
        return self._clear_tables(["tdocs"]).get("tdocs", 0)

    def clear_meetings(self) -> int:
        """Delete every meeting row from the database.

        Returns:
            Number of meetings deleted
        """
        # _clear_tables returns a {table: deleted_count} mapping; absent key -> 0.
        return self._clear_tables(["meetings"]).get("meetings", 0)

    def clear_specs(self) -> dict[str, int]:
        """Clear all spec-related records from database.

        The tables are cleared in a fixed order, ending with ``specs``.

        Returns:
            Mapping of table name to deleted row count.
        """
        spec_tables = ["spec_downloads", "spec_versions", "spec_source_records", "specs"]
        return self._clear_tables(spec_tables)

    def _ensure_reference_data(self) -> None:
        """Populate reference tables for working and subworking groups."""
        database = self.connection
@@ -171,7 +196,7 @@ class _DocDatabase:
        return cursor.fetchone() is not None


class TDocDatabase(_DocDatabase):
class TDocDatabase(DocDatabase):
    # ------------------------------------------------------------------
    # TDoc operations
    # ------------------------------------------------------------------
@@ -694,206 +719,3 @@ class TDocDatabase(_DocDatabase):
            if getattr(current, field) != getattr(candidate, field):
                return True
        return False


class SpecDatabase(_DocDatabase):
    """Database facade for 3GPP specification metadata.

    Persists four related tables: ``specs`` (one row per spec number),
    ``spec_source_records`` (per-source metadata snapshots),
    ``spec_versions`` (per-version rows keyed by spec/version/source) and
    ``spec_downloads`` (download/extraction outcomes).
    """

    # ------------------------------------------------------------------
    # Spec operations
    # ------------------------------------------------------------------
    def upsert_specification(self, specification: Specification) -> tuple[bool, bool]:
        """Upsert a specification record.

        Returns:
            ``(created, changed)`` — ``created`` is True when no row existed;
            ``changed`` is True when an existing row differed in any field.
        """
        existing = self._get_specification(specification.spec_number)
        if existing is None:
            self.connection.add("specs", specification, pk="spec_number")
            return True, False

        changed = self._spec_changed(existing, specification)
        self.connection.add("specs", specification, pk="spec_number")
        return False, changed

    def upsert_spec_source_record(self, record: SpecificationSourceRecord) -> tuple[bool, bool]:
        """Upsert a spec source record.

        A missing ``record_id`` is synthesized as ``"<spec_number>:<source_name>"``.

        Returns:
            ``(created, changed)`` as in :meth:`upsert_specification`.
        """
        record_id = record.record_id or f"{record.spec_number}:{record.source_name}"
        updated_record = record.model_copy(update={"record_id": record_id})
        existing = self._get_spec_source_record(record_id)
        if existing is None:
            self.connection.add("spec_source_records", updated_record, pk="record_id")
            return True, False

        changed = self._spec_source_changed(existing, updated_record)
        self.connection.add("spec_source_records", updated_record, pk="record_id")
        return False, changed

    def upsert_spec_version(self, version: SpecificationVersion) -> tuple[bool, bool]:
        """Upsert a spec version record.

        A missing ``record_id`` is synthesized as
        ``"<spec_number>:<version>:<source_name>"``.

        Returns:
            ``(created, changed)`` as in :meth:`upsert_specification`.
        """
        record_id = version.record_id or f"{version.spec_number}:{version.version}:{version.source_name}"
        updated_version = version.model_copy(update={"record_id": record_id})
        existing = self._get_spec_version(record_id)
        if existing is None:
            self.connection.add("spec_versions", updated_version, pk="record_id")
            return True, False

        changed = self._spec_version_changed(existing, updated_version)
        self.connection.add("spec_versions", updated_version, pk="record_id")
        return False, changed

    def get_spec_versions(self, spec_number: str) -> list[SpecificationVersion]:
        """Get all versions for a spec.

        Returns:
            All stored version rows for ``spec_number``; an empty list when
            the query fails (e.g. the table does not exist yet).
        """
        try:
            cursor = self.connection._db.execute("SELECT * FROM spec_versions WHERE spec_number = ?", (spec_number,))
            columns = [description[0] for description in cursor.description]
            rows = cursor.fetchall()

            result = []
            for row in rows:
                row_dict = dict(zip(columns, row, strict=False))
                result.append(SpecificationVersion(**row_dict))
            return result
        except Exception as exc:
            # Best-effort read: log at debug level (consistent with
            # _get_spec_source_record) instead of swallowing silently.
            _logger.debug("Error fetching spec versions for %s: %s", spec_number, exc)
            return []

    def log_spec_download(self, download: SpecificationDownload) -> None:
        """Persist download/extraction outcomes for a spec version."""
        record_id = download.record_id or f"{download.spec_number}:{download.version}"
        # Convert Path objects to strings for SQLite compatibility.
        # NOTE(review): str() on a None path would store the literal "None" —
        # assumes these path fields are always populated; confirm upstream.
        updated_download = download.model_copy(
            update={
                "record_id": record_id,
                "checkout_path": str(download.checkout_path),
                "document_path": str(download.document_path),
                "attachment_paths": [str(p) for p in download.attachment_paths],
            }
        )
        self.connection.add("spec_downloads", updated_download, pk="record_id")

    def query_specs(self, filters: SpecQueryFilters) -> list[SpecQueryResult]:
        """Query stored spec metadata.

        Applies the filters in ``filters`` (spec numbers, title substring,
        working group, status — the string filters are case-insensitive) and
        annotates each result with per-field differences between sources.
        """
        specs = self._spec_table_rows()
        source_records = self._table_rows("spec_source_records")
        records_by_spec: dict[str, list[SpecificationSourceRecord]] = defaultdict(list)
        for record in source_records:
            records_by_spec[record.spec_number].append(record)

        if filters.spec_numbers:
            allowed = {value.strip() for value in filters.spec_numbers}
            specs = [spec for spec in specs if spec.spec_number in allowed]

        if filters.title:
            needle = filters.title.strip().lower()
            specs = [spec for spec in specs if needle in (spec.title or "").lower()]

        if filters.working_group:
            needle = filters.working_group.strip().lower()
            specs = [spec for spec in specs if (spec.working_group or "").lower() == needle]

        if filters.status:
            needle = filters.status.strip().lower()
            specs = [spec for spec in specs if (spec.status or "").lower() == needle]

        def build_source_differences(records: list[SpecificationSourceRecord]) -> dict[str, dict[str, str | None]]:
            """Map field name -> {source_name: value} for fields where sources disagree."""
            if len(records) < 2:
                return {}

            fields = ("title", "status", "working_group", "latest_version", "spec_type", "series")
            differences: dict[str, dict[str, str | None]] = {}
            for field in fields:
                values: dict[str, str | None] = {}
                normalized_values: set[str] = set()
                for record in records:
                    payload = record.metadata_payload
                    # Payload may still be a raw JSON string; tolerate bad JSON.
                    if isinstance(payload, str):
                        try:
                            payload = json.loads(payload)
                        except json.JSONDecodeError:
                            payload = {}
                    if not isinstance(payload, dict):
                        payload = {}
                    raw_value = payload.get(field)
                    value = str(raw_value) if raw_value is not None else None
                    values[record.source_name] = value
                    # Compare case-insensitively so cosmetic differences don't count.
                    normalized_values.add((value or "").strip().lower())

                if len(normalized_values) > 1:
                    differences[field] = values

            return differences

        return [
            SpecQueryResult(
                spec_number=spec.spec_number,
                title=spec.title,
                status=spec.status,
                working_group=spec.working_group,
                source_differences=build_source_differences(records_by_spec.get(spec.spec_number, [])),
            )
            for spec in specs
        ]

    def clear_specs(self) -> dict[str, int]:
        """Clear all spec-related records from database.

        Returns:
            Mapping of table name to deleted row count.
        """
        return self._clear_tables(["spec_downloads", "spec_versions", "spec_source_records", "specs"])

    def _spec_table_rows(self) -> list[Specification]:
        """Return every row of the ``specs`` table as models."""
        return self._table_rows("specs")

    def _get_specification(self, spec_number: str) -> Specification | None:
        """Fetch one spec by primary key, or None when absent."""
        try:
            return self.connection.model_from_table("specs", spec_number)  # type: ignore[arg-type]
        except KeyError:
            return None

    def _get_spec_source_record(self, record_id: str) -> SpecificationSourceRecord | None:
        """Fetch one source record by id, or None when absent or unreadable."""
        try:
            # Use raw query to handle JSON deserialization manually before model instantiation
            cursor = self.connection._db.execute("SELECT * FROM spec_source_records WHERE record_id = ?", (record_id,))
            row = cursor.fetchone()
            if row is None:
                return None

            columns = [description[0] for description in cursor.description]
            row_dict = dict(zip(columns, row, strict=False))

            # Handle JSON fields
            if "metadata_payload" in row_dict and isinstance(row_dict["metadata_payload"], str):
                try:
                    row_dict["metadata_payload"] = json.loads(row_dict["metadata_payload"])
                except json.JSONDecodeError:
                    row_dict["metadata_payload"] = {}

            if "versions" in row_dict and isinstance(row_dict["versions"], str):
                try:
                    row_dict["versions"] = json.loads(row_dict["versions"])
                except json.JSONDecodeError:
                    row_dict["versions"] = []

            # Handle datetime deserialization
            if "fetched_at" in row_dict and isinstance(row_dict["fetched_at"], str):
                with contextlib.suppress(ValueError, AttributeError):
                    row_dict["fetched_at"] = datetime.fromisoformat(row_dict["fetched_at"])

            return SpecificationSourceRecord(**row_dict)
        except Exception as exc:
            _logger.debug("Error fetching spec source record %s: %s", record_id, exc)
            return None

    def _get_spec_version(self, record_id: str) -> SpecificationVersion | None:
        """Fetch one version row by id, or None when absent."""
        try:
            return self.connection.model_from_table("spec_versions", record_id)
        except KeyError:
            return None

    @staticmethod
    def _spec_changed(current: Specification, candidate: Specification) -> bool:
        """True when any model field differs between the two specs."""
        return any(getattr(current, field) != getattr(candidate, field) for field in Specification.model_fields)

    @staticmethod
    def _spec_source_changed(current: SpecificationSourceRecord, candidate: SpecificationSourceRecord) -> bool:
        """True when any model field differs between the two source records."""
        return any(getattr(current, field) != getattr(candidate, field) for field in SpecificationSourceRecord.model_fields)

    @staticmethod
    def _spec_version_changed(current: SpecificationVersion, candidate: SpecificationVersion) -> bool:
        """True when any model field differs between the two version rows."""
        return any(getattr(current, field) != getattr(candidate, field) for field in SpecificationVersion.model_fields)
Loading