Commit f4458448 authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor(db): migrate database layer to oxyde ORM backend

- Rewrite base.py: manager-based schema creation, PK/UNIQUE DDL
  emission, overwrite-safe connection registration, idempotent
  reference seeding, fixed crawl log start/end write paths
- oxyde_models.py: normalise PK declarations to db_pk=True across
  all models; add missing UNIQUE constraints
- tdocs.py: switch to manager create/upsert; fix datetime timezone
  comparison; manager-based retrieval and filter paths
- meetings.py: align insert/update semantics with oxyde manager;
  async-safe get_existing_meeting_ids
- specs.py: switch insert paths to manager create; manager-based
  get and source/version upsert flows
parent 6e1efc6b
Loading
Loading
Loading
Loading
+112 −36
Original line number Diff line number Diff line
@@ -4,8 +4,10 @@ from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
from typing import Self
from uuid import uuid4

from oxyde import AsyncDatabase, Model
from oxyde.migrations import extract_current_schema

from tdoc_crawler.database.errors import DatabaseError
from tdoc_crawler.database.oxyde_models import (
@@ -40,18 +42,10 @@ class DocDatabase:
    async def __aenter__(self) -> Self:
        try:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._database = AsyncDatabase(self.db_file)
            url = f"sqlite:///{self.db_file.as_posix()}"
            self._database = AsyncDatabase(url, auto_register=True, overwrite=True)
            await self._database.connect()
            # Register all models with the database
            self._database.register(WorkingGroupRecord)
            self._database.register(SubWorkingGroupRecord)
            self._database.register(CrawlLogEntry)
            self._database.register(MeetingMetadata)
            self._database.register(TDocMetadata)
            self._database.register(Specification)
            self._database.register(SpecificationSourceRecord)
            self._database.register(SpecificationVersion)
            self._database.register(SpecificationDownload)
            await self._ensure_tables_exist()
            await self._ensure_reference_data()
        except Exception as exc:
            raise DatabaseError("database-initialization-failed", detail=str(exc)) from exc
@@ -100,7 +94,9 @@ class DocDatabase:
        Returns:
            The ID of the created crawl log entry
        """
        log_id = f"{crawl_type}-{uuid4().hex}"
        entry = CrawlLogEntry(
            log_id=log_id,
            crawl_type=crawl_type,
            working_groups=filters or [],
            incremental=incremental,
@@ -111,7 +107,7 @@ class DocDatabase:
            errors_count=0,
            status="RUNNING",
        )
        await self.connection.save(entry)
        await CrawlLogEntry.objects.create(instance=entry)
        return entry.log_id

    async def log_crawl_end(
@@ -129,37 +125,117 @@ class DocDatabase:
            items_updated: Number of existing items updated
            errors_count: Number of errors encountered
        """
        entry = await self.connection.get(CrawlLogEntry, crawl_id)
        if entry:
            entry.end_time = datetime.now()
            entry.items_added = items_added
            entry.items_updated = items_updated
            entry.errors_count = errors_count
            entry.status = "COMPLETED"
            await self.connection.save(entry)
        await CrawlLogEntry.objects.filter(log_id=crawl_id).update(
            end_time=datetime.now(),
            items_added=items_added,
            items_updated=items_updated,
            errors_count=errors_count,
            status="COMPLETED",
        )

    async def _ensure_tables_exist(self) -> None:
        """Create database tables from Oxyde models if they don't exist.

        Reads the table definitions of all registered models via
        extract_current_schema() and emits one raw
        ``CREATE TABLE IF NOT EXISTS`` statement per table through
        Oxyde's IR envelope for raw SQL.
        """
        if self._database is None:
            return

        # Map oxyde "python_type" strings onto SQLite storage classes.
        # Anything unrecognised falls back to TEXT.
        type_map = {
            "int": "INTEGER",
            "bool": "INTEGER",
            "str": "TEXT",
            "float": "REAL",
            "datetime": "TEXT",
            "date": "TEXT",
            "bytes": "BLOB",
            "list": "TEXT",  # stored as JSON
            "dict": "TEXT",  # stored as JSON
        }

        # Extract current schema from registered models
        schema = extract_current_schema(dialect="sqlite")

        for table_name, table_def in schema.get("tables", {}).items():
            field_defs = table_def.get("fields", [])
            if not field_defs:
                continue

            # Collect primary-key columns first: a single PK is declared
            # inline on its column, a composite PK becomes a table-level
            # constraint appended after the columns.
            pk_columns = [f["name"] for f in field_defs if bool(f.get("primary_key", False))]

            columns: list[str] = []
            for field_def in field_defs:
                name = field_def["name"]
                ddl = type_map.get(field_def["python_type"], "TEXT")

                if not field_def.get("nullable", True):
                    ddl += " NOT NULL"

                default_value = field_def.get("default")
                if default_value is not None and default_value != "NULL":
                    ddl += f" DEFAULT {default_value}"

                if bool(field_def.get("primary_key", False)) and len(pk_columns) == 1:
                    ddl += " PRIMARY KEY"
                elif bool(field_def.get("unique", False)):
                    ddl += " UNIQUE"

                columns.append(f"{name} {ddl}")

            if len(pk_columns) > 1:
                columns.append(f"PRIMARY KEY ({', '.join(pk_columns)})")

            create_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"

            # Oxyde executes raw SQL through its IR protocol envelope.
            raw_ir = {"proto": 1, "op": "raw", "table": "", "sql": create_sql}
            await self._database.execute(raw_ir)

    async def _ensure_reference_data(self) -> None:
        """Populate reference tables for working and subworking groups.

        Idempotent: each seed record is inserted only when no row with the
        same natural key (tbid / subtb) already exists, so repeated startup
        runs do not duplicate reference data.
        """
        # NOTE(review): removed a stale `await self.connection.save(...)`
        # left over from the pre-oxyde API; the manager `create` call inside
        # the `if existing is None` branch is the single write path.
        for record in WORKING_GROUP_RECORDS:
            existing = await WorkingGroupRecord.objects.filter(tbid=record.tbid).first()
            if existing is None:
                oxyde_record = WorkingGroupRecord(
                    tbid=record.tbid,
                    code=record.code,
                    name=record.name,
                )
                await WorkingGroupRecord.objects.create(instance=oxyde_record)
        for record in SUBWORKING_GROUP_RECORDS:
            existing = await SubWorkingGroupRecord.objects.filter(subtb=record.subtb).first()
            if existing is None:
                oxyde_record = SubWorkingGroupRecord(
                    subtb=record.subtb,
                    tbid=record.tbid,
                    code=record.code,
                    name=record.name,
                )
                await SubWorkingGroupRecord.objects.create(instance=oxyde_record)

    async def _table_rows(self, model_class: type) -> list:
        """Fetch all rows for a given model class.

        Args:
            model_class: An oxyde Model subclass whose manager is queried.

        Raises:
            DatabaseError: if the underlying read fails; the original
                exception is chained as the cause.
        """
        # Removed an unreachable duplicate `return` left over from the
        # pre-oxyde `self.connection.all(...)` API.
        try:
            return await model_class.objects.all()
        except Exception as exc:
            _logger.error("Failed to read table for %s: %s", model_class.__name__, exc)
            raise DatabaseError("table-read-failed", detail=f"Model '{model_class.__name__}': {exc}") from exc
@@ -196,8 +272,8 @@ class DocDatabase:
                counts[table] = 0
                continue
            try:
                count = await self.connection.count(model_class)
                await self.connection.delete_all(model_class)
                count = await model_class.objects.count()
                await model_class.objects.delete()
                counts[table] = count
            except (OSError, ValueError) as exc:
                _logger.error("Failed to clear table %s: %s", table, exc)
+3 −3
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ class MeetingDatabase(DocDatabase):
                created_at=now,
                updated_at=now,
            )
            await self.connection.save(created)
            await MeetingMetadata.objects.create(instance=created)
            return True, False

        created_at = record.created_at or existing.created_at
@@ -74,7 +74,7 @@ class MeetingDatabase(DocDatabase):
            created_at=created_at,
            updated_at=now,
        )
        await self.connection.save(updated)
        await updated.save()
        return False, changed

    async def bulk_upsert_meetings(
@@ -273,7 +273,7 @@ class MeetingDatabase(DocDatabase):

    async def _get_meeting(self, meeting_id: int) -> MeetingMetadata | None:
        """Get a meeting by ID."""
        return await self.connection.get(MeetingMetadata, meeting_id)
        return await MeetingMetadata.objects.filter(meeting_id=meeting_id).first()

    # ------------------------------------------------------------------
    # Normalisation helpers
+60 −9
Original line number Diff line number Diff line
@@ -6,33 +6,41 @@ These models replace the previous pydantic-sqlite based persistence layer.

from __future__ import annotations

import json
from datetime import date, datetime
from typing import Any

from oxyde import Field, Model
from pydantic import field_validator


class WorkingGroupRecord(Model):
    """Reference data for working groups."""

    # Removed a stale duplicate `tbid` declaration (old `primary_key=True`
    # spelling) left next to the new `db_pk=True` one — only the oxyde
    # `db_pk` form is kept.
    tbid: int = Field(db_pk=True)
    code: str
    name: str

    class Meta:
        is_table = True


class SubWorkingGroupRecord(Model):
    """Reference data for subworking groups."""

    # Removed a stale duplicate `subtb` declaration (old `primary_key=True`
    # spelling); only the oxyde `db_pk` form is kept.
    subtb: int = Field(db_pk=True)
    tbid: int
    code: str
    name: str

    class Meta:
        is_table = True


class CrawlLogEntry(Model):
    """Log entry for crawl operations."""

    log_id: str = Field(primary_key=True)
    log_id: str = Field(db_pk=True)
    crawl_type: str
    start_time: datetime
    end_time: datetime | None = None
@@ -44,11 +52,14 @@ class CrawlLogEntry(Model):
    status: str = Field(default="RUNNING")
    created_at: datetime = Field(default_factory=datetime.utcnow)

    class Meta:
        is_table = True


class MeetingMetadata(Model):
    """Metadata for 3GPP meetings."""

    meeting_id: int = Field(primary_key=True)
    meeting_id: int = Field(db_pk=True)
    tbid: int
    subtb: int | None = None
    short_name: str
@@ -63,11 +74,14 @@ class MeetingMetadata(Model):
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

    class Meta:
        is_table = True


class TDocMetadata(Model):
    """Metadata for TDoc documents."""

    tdoc_id: str = Field(primary_key=True)
    tdoc_id: str = Field(db_pk=True)
    meeting_id: int
    title: str
    url: str | None = None
@@ -86,11 +100,14 @@ class TDocMetadata(Model):
    validated: bool = Field(default=False)
    validation_failed: bool = Field(default=False)

    class Meta:
        is_table = True


class Specification(Model):
    """Canonical specification identity and metadata."""

    spec_number: str = Field(primary_key=True)
    spec_number: str = Field(db_pk=True)
    spec_number_compact: str
    spec_type: str
    title: str
@@ -99,11 +116,14 @@ class Specification(Model):
    series: str
    latest_version: str | None = None

    class Meta:
        is_table = True


class SpecificationSourceRecord(Model):
    """Source-specific metadata snapshot for specifications."""

    record_id: str = Field(primary_key=True)
    record_id: str = Field(db_pk=True)
    spec_number: str
    source_name: str
    source_identifier: str | None = None
@@ -111,21 +131,49 @@ class SpecificationSourceRecord(Model):
    versions: list[str] = Field(default_factory=list)
    fetched_at: datetime | None = None

    @field_validator("metadata_payload", mode="before")
    @classmethod
    def _parse_metadata_payload(cls, value: dict[str, Any] | str) -> dict[str, Any]:
        if isinstance(value, str):
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError:
                return {}
            return parsed if isinstance(parsed, dict) else {}
        return value

    @field_validator("versions", mode="before")
    @classmethod
    def _parse_versions(cls, value: list[str] | str) -> list[str]:
        if isinstance(value, str):
            try:
                parsed = json.loads(value)
            except json.JSONDecodeError:
                return []
            return [str(item) for item in parsed] if isinstance(parsed, list) else []
        return [str(item) for item in value]

    class Meta:
        is_table = True


class SpecificationVersion(Model):
    """Version details for specifications."""

    # Removed a stale duplicate `record_id` declaration (old
    # `primary_key=True` spelling); only the oxyde `db_pk` form is kept.
    record_id: str = Field(db_pk=True)
    spec_number: str
    version: str
    file_name: str
    source_name: str

    class Meta:
        is_table = True


class SpecificationDownload(Model):
    """Download and extraction outcome for spec versions."""

    record_id: str = Field(primary_key=True)
    record_id: str = Field(db_pk=True)
    spec_number: str
    version: str
    download_url: str
@@ -137,6 +185,9 @@ class SpecificationDownload(Model):
    outcome_message: str | None = None
    extracted_at: datetime | None = None

    class Meta:
        is_table = True


__all__ = [
    "CrawlLogEntry",
+10 −11
Original line number Diff line number Diff line
@@ -139,11 +139,11 @@ class SpecDatabase(DocDatabase):
        """Upsert a specification record."""
        existing = await self._get_specification(specification.spec_number)
        if existing is None:
            await self.connection.save(specification)
            await Specification.objects.create(instance=specification)
            return True, False

        changed = self._spec_changed(existing, specification)
        await self.connection.save(specification)
        await specification.save()
        return False, changed

    async def upsert_spec_source_record(self, record: SpecificationSourceRecord) -> tuple[bool, bool]:
@@ -160,11 +160,11 @@ class SpecDatabase(DocDatabase):
        )
        existing = await self._get_spec_source_record(record_id)
        if existing is None:
            await self.connection.save(updated_record)
            await SpecificationSourceRecord.objects.create(instance=updated_record)
            return True, False

        changed = self._spec_source_changed(existing, updated_record)
        await self.connection.save(updated_record)
        await updated_record.save()
        return False, changed

    async def upsert_spec_version(self, version: SpecificationVersion) -> tuple[bool, bool]:
@@ -179,18 +179,17 @@ class SpecDatabase(DocDatabase):
        )
        existing = await self._get_spec_version(record_id)
        if existing is None:
            await self.connection.save(updated_version)
            await SpecificationVersion.objects.create(instance=updated_version)
            return True, False

        changed = self._spec_version_changed(existing, updated_version)
        await self.connection.save(updated_version)
        await updated_version.save()
        return False, changed

    async def get_spec_versions(self, spec_number: str) -> list[SpecificationVersion]:
        """Get all versions for a spec."""
        try:
            query = self.connection.query(SpecificationVersion).filter(spec_number=spec_number)
            return await query.all()
            return await SpecificationVersion.objects.filter(spec_number=spec_number).all()
        except Exception as exc:
            _logger.error("Failed to get spec versions for '%s': %s", spec_number, exc)
            raise DatabaseError("spec-versions-read-failed", detail=str(exc)) from exc
@@ -457,17 +456,17 @@ class SpecDatabase(DocDatabase):
        return await self._table_rows(Specification)

    async def _get_specification(self, spec_number: str) -> Specification | None:
        return await self.connection.get(Specification, spec_number)
        return await Specification.objects.filter(spec_number=spec_number).first()

    async def _get_spec_source_record(self, record_id: str) -> SpecificationSourceRecord | None:
        try:
            return await self.connection.get(SpecificationSourceRecord, record_id)
            return await SpecificationSourceRecord.objects.filter(record_id=record_id).first()
        except Exception as exc:
            _logger.warning("Failed to fetch spec source record '%s': %s", record_id, exc)
            return None

    async def _get_spec_version(self, record_id: str) -> SpecificationVersion | None:
        return await self.connection.get(SpecificationVersion, record_id)
        return await SpecificationVersion.objects.filter(record_id=record_id).first()

    @staticmethod
    def _spec_changed(current: Specification, candidate: Specification) -> bool:
+21 −5
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ class TDocDatabase(MeetingDatabase):
                    "date_created": record.date_created or record.date_retrieved,
                },
            )
            await self.connection.save(created_record)
            await TDocMetadata.objects.create(instance=created_record)
            return True, False

        if record.date_created is None and existing.date_created is not None:
@@ -57,7 +57,7 @@ class TDocDatabase(MeetingDatabase):

        changed = self._tdoc_changed(existing, record)
        updated_record = self._clone_tdoc(record, {"date_updated": now})
        await self.connection.save(updated_record)
        await updated_record.save()
        return False, changed

    async def bulk_upsert_tdocs(
@@ -170,7 +170,7 @@ class TDocDatabase(MeetingDatabase):

    async def _get_tdoc(self, tdoc_id: str) -> TDocMetadata | None:
        """Get a TDoc by ID."""
        return await self.connection.get(TDocMetadata, tdoc_id.upper())
        return await TDocMetadata.objects.filter(tdoc_id=tdoc_id.upper()).first()

    @classmethod
    def _meeting_matches_filters(
@@ -216,11 +216,27 @@ class TDocDatabase(MeetingDatabase):
        end_date: datetime | None,
    ) -> list[TDocMetadata]:
        """Filter records by retrieval datetime bounds."""

        def _normalized_bound(value: datetime, bound: datetime) -> tuple[datetime, datetime]:
            if value.tzinfo is None and bound.tzinfo is not None:
                return value.replace(tzinfo=UTC), bound
            if value.tzinfo is not None and bound.tzinfo is None:
                return value, bound.replace(tzinfo=UTC)
            return value, bound

        filtered = records
        if start_date is not None:
            filtered = [record for record in filtered if record.date_retrieved and record.date_retrieved >= start_date]
            filtered = [
                record
                for record in filtered
                if record.date_retrieved and _normalized_bound(record.date_retrieved, start_date)[0] >= _normalized_bound(record.date_retrieved, start_date)[1]
            ]
        if end_date is not None:
            filtered = [record for record in filtered if record.date_retrieved and record.date_retrieved <= end_date]
            filtered = [
                record
                for record in filtered
                if record.date_retrieved and _normalized_bound(record.date_retrieved, end_date)[0] <= _normalized_bound(record.date_retrieved, end_date)[1]
            ]
        return filtered

    @staticmethod