Commit 1a152040 authored by Jan Reimes's avatar Jan Reimes
Browse files

specs/database: persist specification records and avoid circular imports

parent d3f3a69b
Loading
Loading
Loading
Loading
+128 −2
Original line number Diff line number Diff line
"""Database access layer backed by pydantic_sqlite."""

from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Iterable
from datetime import UTC, datetime
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING

from pydantic_sqlite import DataBase

@@ -22,8 +21,12 @@ from tdoc_crawler.models import (
    WorkingGroup,
    utc_now,
)
from tdoc_crawler.models.specs import Specification, SpecificationDownload, SpecificationSourceRecord, SpecificationVersion
from tdoc_crawler.models.subworking_groups import SUBWORKING_GROUP_RECORDS

if TYPE_CHECKING:
    from tdoc_crawler.specs.query import SpecQueryFilters, SpecQueryResult


class TDocDatabase:
    """High-level facade for all database operations."""
@@ -371,6 +374,92 @@ class TDocDatabase:
            "name": record.name,
        }

    # ------------------------------------------------------------------
    # Spec operations
    # ------------------------------------------------------------------
    def upsert_specification(self, specification: Specification) -> tuple[bool, bool]:
        """Upsert a specification record."""
        existing = self._get_specification(specification.spec_number)
        if existing is None:
            self.connection.add("specs", specification, pk="spec_number")
            return True, False

        changed = self._spec_changed(existing, specification)
        self.connection.add("specs", specification, pk="spec_number")
        return False, changed

    def upsert_spec_source_record(self, record: SpecificationSourceRecord) -> tuple[bool, bool]:
        """Upsert a spec source record."""
        record_id = record.record_id or f"{record.spec_number}:{record.source_name}"
        updated_record = record.model_copy(update={"record_id": record_id})
        existing = self._get_spec_source_record(record_id)
        if existing is None:
            self.connection.add("spec_source_records", updated_record, pk="record_id")
            return True, False

        changed = self._spec_source_changed(existing, updated_record)
        self.connection.add("spec_source_records", updated_record, pk="record_id")
        return False, changed

    def upsert_spec_version(self, version: SpecificationVersion) -> tuple[bool, bool]:
        """Upsert a spec version record."""
        record_id = version.record_id or f"{version.spec_number}:{version.version}:{version.source_name}"
        updated_version = version.model_copy(update={"record_id": record_id})
        existing = self._get_spec_version(record_id)
        if existing is None:
            self.connection.add("spec_versions", updated_version, pk="record_id")
            return True, False

        changed = self._spec_version_changed(existing, updated_version)
        self.connection.add("spec_versions", updated_version, pk="record_id")
        return False, changed

    def log_spec_download(self, download: SpecificationDownload) -> None:
        """Persist download/extraction outcomes for a spec version."""
        record_id = download.record_id or f"{download.spec_number}:{download.version}"
        # Convert Path objects to strings for SQLite compatibility
        updated_download = download.model_copy(
            update={
                "record_id": record_id,
                "checkout_path": str(download.checkout_path),
                "document_path": str(download.document_path),
                "attachment_paths": [str(p) for p in download.attachment_paths],
            }
        )
        self.connection.add("spec_downloads", updated_download, pk="record_id")

    def query_specs(self, filters: SpecQueryFilters) -> list[SpecQueryResult]:
        """Return the stored specifications that satisfy every active filter.

        A filter takes part only when its value on ``filters`` is truthy:

        * ``spec_numbers``: exact match against the stripped values.
        * ``title``: case-insensitive substring match.
        * ``working_group`` / ``status``: case-insensitive equality.

        Missing (falsy) ``title``, ``working_group`` and ``status`` values on a
        stored spec are compared as empty strings.

        Args:
            filters: Query criteria.

        Returns:
            One ``SpecQueryResult`` per matching spec, in table-row order.
        """
        # Imported lazily: at module level this module is only imported under
        # TYPE_CHECKING to avoid a circular import.
        from tdoc_crawler.specs.query import SpecQueryResult  # noqa: PLC0415

        rows = self._spec_table_rows()
        checks = []

        if filters.spec_numbers:
            wanted_numbers = {number.strip() for number in filters.spec_numbers}
            checks.append(lambda row: row.spec_number in wanted_numbers)

        if filters.title:
            title_fragment = filters.title.strip().lower()
            checks.append(lambda row: title_fragment in (row.title or "").lower())

        if filters.working_group:
            wanted_group = filters.working_group.strip().lower()
            checks.append(lambda row: (row.working_group or "").lower() == wanted_group)

        if filters.status:
            wanted_status = filters.status.strip().lower()
            checks.append(lambda row: (row.status or "").lower() == wanted_status)

        results = []
        for row in rows:
            # Checks run in the order they were registered and stop at the first miss.
            if not all(check(row) for check in checks):
                continue
            results.append(
                SpecQueryResult(
                    spec_number=row.spec_number,
                    title=row.title,
                    status=row.status,
                    working_group=row.working_group,
                )
            )
        return results

    # ------------------------------------------------------------------
    # Crawl logging and statistics
    # ------------------------------------------------------------------
@@ -538,6 +627,10 @@ class TDocDatabase:
            "tdocs": TDocMetadata,
            "meetings": MeetingMetadata,
            "crawl_log": CrawlLogEntry,
            "specs": Specification,
            "spec_source_records": SpecificationSourceRecord,
            "spec_versions": SpecificationVersion,
            "spec_downloads": SpecificationDownload,
        }

        if table not in model_map:
@@ -584,6 +677,27 @@ class TDocDatabase:
        except KeyError:
            return None

    def _spec_table_rows(self) -> list[Specification]:
        return self._table_rows("specs")

    def _get_specification(self, spec_number: str) -> Specification | None:
        try:
            return self.connection.model_from_table("specs", spec_number)  # type: ignore[arg-type]
        except KeyError:
            return None

    def _get_spec_source_record(self, record_id: str) -> SpecificationSourceRecord | None:
        try:
            return self.connection.model_from_table("spec_source_records", record_id)  # type: ignore[arg-type]
        except KeyError:
            return None

    def _get_spec_version(self, record_id: str) -> SpecificationVersion | None:
        try:
            return self.connection.model_from_table("spec_versions", record_id)  # type: ignore[arg-type]
        except KeyError:
            return None

    @staticmethod
    def _tdoc_changed(current: TDocMetadata, candidate: TDocMetadata) -> bool:
        for field in TDocMetadata.model_fields:
@@ -601,3 +715,15 @@ class TDocDatabase:
            if getattr(current, field) != getattr(candidate, field):
                return True
        return False

    @staticmethod
    def _spec_changed(current: Specification, candidate: Specification) -> bool:
        """Report whether any ``Specification`` model field differs between the two records."""
        for field_name in Specification.model_fields:
            # Stop at the first differing field.
            if getattr(current, field_name) != getattr(candidate, field_name):
                return True
        return False

    @staticmethod
    def _spec_source_changed(current: SpecificationSourceRecord, candidate: SpecificationSourceRecord) -> bool:
        """Report whether any ``SpecificationSourceRecord`` model field differs between the two records."""
        for field_name in SpecificationSourceRecord.model_fields:
            # Stop at the first differing field.
            if getattr(current, field_name) != getattr(candidate, field_name):
                return True
        return False

    @staticmethod
    def _spec_version_changed(current: SpecificationVersion, candidate: SpecificationVersion) -> bool:
        """Report whether any ``SpecificationVersion`` model field differs between the two records."""
        for field_name in SpecificationVersion.model_fields:
            # Stop at the first differing field.
            if getattr(current, field_name) != getattr(candidate, field_name):
                return True
        return False
+3 −3
Original line number Diff line number Diff line
@@ -13,10 +13,11 @@ from __future__ import annotations
import logging
from pathlib import Path

from tdoc_crawler.models.base import HttpCacheConfig
from tdoc_crawler.models.tdocs import TDocMetadata
from tdoc_crawler.crawlers.portal import extract_tdoc_url_from_portal, fetch_tdoc_metadata
from tdoc_crawler.crawlers.whatthespec import resolve_via_whatthespec
from tdoc_crawler.models.base import HttpCacheConfig, PortalCredentials
from tdoc_crawler.models.tdocs import TDocMetadata

logger = logging.getLogger(__name__)


@@ -48,7 +49,6 @@ def fetch_tdoc(
    """
    # Import here to avoid circular imports


    if use_whatthespec:
        # Always use WhatTheSpec method (Method 3)
        logger.debug(f"Fetching {tdoc_id} via WhatTheSpec API")