Commit e5157cbd authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor(database): extract helpers, parallelize spec crawl, remove TYPE_CHECKING

parent d84c5f43
Loading
Loading
Loading
Loading
+56 −60
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from pathlib import Path
from typing import Self
from uuid import uuid4

from oxyde import AsyncDatabase, Model
from oxyde import AsyncDatabase, Model, execute_raw
from oxyde.migrations import extract_current_schema

from tdoc_crawler.database.errors import DatabaseError
@@ -27,6 +27,51 @@ from tdoc_crawler.models.working_groups import WORKING_GROUP_RECORDS

_logger = get_logger(__name__)

_PYTHON_TO_SQLITE_TYPE: dict[str, str] = {
    "int": "INTEGER",
    "str": "TEXT",
    "float": "REAL",
    "datetime": "TEXT",
    "date": "TEXT",
    "bool": "INTEGER",
    "bytes": "BLOB",
    "list": "TEXT",
    "dict": "TEXT",
}


def _build_column_sql(field: dict, pk_count: int) -> str:
    """Build a single column SQL definition from a schema field dict."""
    col_name = field["name"]
    sql_type = _PYTHON_TO_SQLITE_TYPE.get(field["python_type"], "TEXT")

    if not field.get("nullable", True):
        sql_type += " NOT NULL"

    default = field.get("default")
    if default is not None and default != "NULL":
        sql_type += f" DEFAULT {default}"

    is_pk = bool(field.get("primary_key", False))
    is_unique = bool(field.get("unique", False))
    if is_pk and pk_count == 1:
        sql_type += " PRIMARY KEY"
    elif is_unique:
        sql_type += " UNIQUE"

    return f"{col_name} {sql_type}"


def _build_column_defs(fields: list[dict]) -> tuple[list[str], list[str]]:
    """Build all column definitions and collect primary key field names.

    Args:
        fields: Schema field dicts describing one table.

    Returns:
        Tuple of (column SQL strings, primary key field names).
    """
    pk_fields: list[str] = []
    for field in fields:
        if field.get("primary_key", False):
            pk_fields.append(field["name"])

    pk_count = len(pk_fields)
    return [_build_column_sql(field, pk_count) for field in fields], pk_fields


class DocDatabase:
    """High-level facade for all database operations using Oxyde ORM."""
@@ -59,7 +104,8 @@ class DocDatabase:
            await self._ensure_tables_exist()
            await self._ensure_reference_data()
        except Exception as exc:
            raise DatabaseError("database-initialization-failed", detail=str(exc)) from exc
            msg = "database-initialization-failed"
            raise DatabaseError(msg, detail=str(exc)) from exc
        return self

    async def __aexit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: object | None) -> None:
@@ -146,71 +192,19 @@ class DocDatabase:
        if self._database is None:
            return

        # Extract current schema from registered models
        schema = extract_current_schema(dialect="sqlite")

        # Build and execute CREATE TABLE IF NOT EXISTS for each table
        for table_name, table_def in schema.get("tables", {}).items():
            fields = table_def.get("fields", [])
            if not fields:
                continue

            # Build column definitions
            column_defs: list[str] = []
            pk_fields: list[str] = []
            for field in fields:
                if bool(field.get("primary_key", False)):
                    pk_fields.append(field["name"])

            for field in fields:
                col_name = field["name"]
                python_type = field["python_type"]

                # Map Python types to SQLite types
                sql_type: str
                if python_type == "int":
                    sql_type = "INTEGER"
                elif python_type == "str":
                    sql_type = "TEXT"
                elif python_type == "float":
                    sql_type = "REAL"
                elif python_type in ("datetime", "date"):
                    sql_type = "TEXT"
                elif python_type == "bool":
                    sql_type = "INTEGER"
                elif python_type == "bytes":
                    sql_type = "BLOB"
                elif python_type in ("list", "dict"):
                    sql_type = "TEXT"  # Store as JSON
                else:
                    sql_type = "TEXT"

                nullable = field.get("nullable", True)
                if not nullable:
                    sql_type += " NOT NULL"

                default = field.get("default")
                if default is not None and default != "NULL":
                    sql_type += f" DEFAULT {default}"

                is_pk = bool(field.get("primary_key", False))
                is_unique = bool(field.get("unique", False))
                if is_pk and len(pk_fields) == 1:
                    sql_type += " PRIMARY KEY"
                elif is_unique:
                    sql_type += " UNIQUE"

                column_defs.append(f"{col_name} {sql_type}")

            column_defs, pk_fields = _build_column_defs(fields)
            if len(pk_fields) > 1:
                pk_cols = ", ".join(pk_fields)
                column_defs.append(f"PRIMARY KEY ({pk_cols})")
                column_defs.append(f"PRIMARY KEY ({', '.join(pk_fields)})")

            # Build complete CREATE TABLE IF NOT EXISTS statement
            columns_sql = ", ".join(column_defs)
            create_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({columns_sql})"

            # Execute the SQL using Oxyde's IR format for raw SQL
            raw_ir = {"proto": 1, "op": "raw", "table": "", "sql": create_sql}
            await self._database.execute(raw_ir)

@@ -241,8 +235,9 @@ class DocDatabase:
        try:
            return await model_class.objects.all()
        except Exception as exc:
            _logger.error("Failed to read table for %s: %s", model_class.__name__, exc)
            raise DatabaseError("table-read-failed", detail=f"Model '{model_class.__name__}': {exc}") from exc
            _logger.exception("Failed to read table for %s: %s", model_class.__name__, exc)
            msg = "table-read-failed"
            raise DatabaseError(msg, detail=f"Model '{model_class.__name__}': {exc}") from exc

    async def _clear_tables(self, table_names: str | Iterable[str]) -> dict[str, int]:
        """Clear specified tables from database.
@@ -277,10 +272,11 @@ class DocDatabase:
                continue
            try:
                count = await model_class.objects.count()
                await model_class.objects.delete()
                table_name = model_class.get_table_name()
                await execute_raw(f"DELETE FROM {table_name}", using="default")
                counts[table] = count
            except (OSError, ValueError) as exc:
                _logger.error("Failed to clear table %s: %s", table, exc)
                _logger.exception("Failed to clear table %s: %s", table, exc)
                counts[table] = 0
        return counts

+51 −48
Original line number Diff line number Diff line
"""Meeting database operations."""

import asyncio
from collections import defaultdict
from collections.abc import Callable, Iterable
from datetime import datetime

from tdoc_crawler.config.settings import PathConfig
from tdoc_crawler.database.base import DocDatabase
from tdoc_crawler.database.oxyde_models import CrawlLogEntry, MeetingMetadata, TDocMetadata
from tdoc_crawler.logging import get_logger
@@ -19,6 +17,31 @@ from tdoc_crawler.utils.normalization import normalize_portal_meeting_name
_logger = get_logger(__name__)


def _filter_by_working_group(meetings: list[MeetingMetadata], allowed_tbids: set[int]) -> list[MeetingMetadata]:
    """Keep only meetings whose TBID is in the allowed set."""
    return list(filter(lambda meeting: meeting.tbid in allowed_tbids, meetings))


def _filter_by_subgroup(meetings: list[MeetingMetadata], allowed_codes: set[str]) -> list[MeetingMetadata]:
    """Keep only meetings whose subgroup code is in the allowed set."""
    kept: list[MeetingMetadata] = []
    for meeting in meetings:
        # A meeting passes only when its subtb is known AND maps to an allowed code.
        if meeting.subtb in SUBTB_INDEX and SUBTB_INDEX[meeting.subtb].code in allowed_codes:
            kept.append(meeting)
    return kept


def _filter_with_files(meetings: list[MeetingMetadata]) -> list[MeetingMetadata]:
    """Keep only meetings that have a (truthy) files URL."""
    return list(filter(lambda meeting: meeting.files_url, meetings))


def _filter_by_start_date(meetings: list[MeetingMetadata], start_date: datetime) -> list[MeetingMetadata]:
    """Keep meetings with a start date on or after the given date."""
    # Meetings without a start date are dropped by the truthiness check.
    return list(filter(lambda meeting: meeting.start_date and meeting.start_date >= start_date, meetings))


def _filter_by_end_date(meetings: list[MeetingMetadata], end_date: datetime) -> list[MeetingMetadata]:
    """Keep meetings with an end date on or before the given date."""
    # Meetings without an end date are dropped by the truthiness check.
    return list(filter(lambda meeting: meeting.end_date and meeting.end_date <= end_date, meetings))


class MeetingDatabase(DocDatabase):
    """Database operations for meeting metadata."""

@@ -119,23 +142,7 @@ class MeetingDatabase(DocDatabase):
            List of matching meeting metadata
        """
        meetings = await self._table_rows(MeetingMetadata)

        if config.working_groups:
            allowed = {wg.tbid for wg in config.working_groups}
            meetings = [meeting for meeting in meetings if meeting.tbid in allowed]

        if config.subgroups:
            allowed_subgroups = {value.strip().upper() for value in config.subgroups}
            meetings = [meeting for meeting in meetings if meeting.subtb in SUBTB_INDEX and SUBTB_INDEX[meeting.subtb].code in allowed_subgroups]

        if not config.include_without_files:
            meetings = [meeting for meeting in meetings if meeting.files_url]

        # Date range filters
        if config.start_date is not None:
            meetings = [meeting for meeting in meetings if meeting.start_date and meeting.start_date >= config.start_date]
        if config.end_date is not None:
            meetings = [meeting for meeting in meetings if meeting.end_date and meeting.end_date <= config.end_date]
        meetings = self._apply_meeting_filters(meetings, config)

        descending = config.order.value.lower() == "desc"
        meetings.sort(
@@ -277,6 +284,30 @@ class MeetingDatabase(DocDatabase):
        """Get a meeting by ID."""
        return await MeetingMetadata.objects.filter(meeting_id=meeting_id).first()

    @staticmethod
    def _apply_meeting_filters(
        meetings: list[MeetingMetadata],
        config: MeetingQueryConfig,
    ) -> list[MeetingMetadata]:
        """Apply configured filters to a meeting list.

        Args:
            meetings: Candidate meetings to narrow down.
            config: Query configuration holding the active filters.

        Returns:
            The filtered meeting list; relative order is preserved.
        """
        result = meetings

        if config.working_groups:
            result = _filter_by_working_group(result, {wg.tbid for wg in config.working_groups})

        if config.subgroups:
            result = _filter_by_subgroup(result, {code.strip().upper() for code in config.subgroups})

        if not config.include_without_files:
            result = _filter_with_files(result)

        # Date-range filters: each bound is applied independently when set.
        if config.start_date is not None:
            result = _filter_by_start_date(result, config.start_date)
        if config.end_date is not None:
            result = _filter_by_end_date(result, config.end_date)

        return result

    # ------------------------------------------------------------------
    # Normalisation helpers
    # ------------------------------------------------------------------
@@ -296,32 +327,4 @@ class MeetingDatabase(DocDatabase):
        return False


async def _resolve_meeting_short_name(meeting_id: int) -> str | None:
    """Look up a meeting's short name by ID.

    Opens a fresh database connection for the duration of the lookup.

    Args:
        meeting_id: Meeting identifier.

    Returns:
        Meeting short name (e.g., "SA4#134") if found, None otherwise.
    """
    async with MeetingDatabase(PathConfig().db_file) as database:
        meeting = await database._get_meeting(meeting_id)
        if meeting is None:
            return None
        return meeting.short_name


def get_meeting_short_name(meeting_id: int) -> str | None:
    """Sync convenience wrapper to look up a meeting's short name by ID.

    Uses ``asyncio.run()`` — must not be called from an async context.

    Args:
        meeting_id: Meeting identifier.

    Returns:
        Meeting short name (e.g., "SA4#134") if found, None otherwise.
    """
    lookup = _resolve_meeting_short_name(meeting_id)
    return asyncio.run(lookup)


__all__ = ["MeetingDatabase", "get_meeting_short_name"]
__all__ = ["MeetingDatabase"]
+168 −123

File changed.

Preview size limit exceeded, changes collapsed.

+1 −2
Original line number Diff line number Diff line
@@ -165,8 +165,7 @@ class TDocDatabase(MeetingDatabase):
        filtered = self._filter_by_pattern(filtered, config.title_pattern, lambda record: record.title, exclude=False)
        filtered = self._filter_by_pattern(filtered, config.title_pattern_exclude, lambda record: record.title, exclude=True)
        filtered = self._filter_by_pattern(filtered, config.agenda_pattern, lambda record: record.agenda_item_text, exclude=False)
        filtered = self._filter_by_pattern(filtered, config.agenda_pattern_exclude, lambda record: record.agenda_item_text, exclude=True)
        return filtered
        return self._filter_by_pattern(filtered, config.agenda_pattern_exclude, lambda record: record.agenda_item_text, exclude=True)

    async def _get_tdoc(self, tdoc_id: str) -> TDocMetadata | None:
        """Get a TDoc by ID."""