Commit 3a453305 authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor: enhance type annotations and code structure across modules

parent cb8d9d25
Loading
Loading
Loading
Loading
+60 −64
Original line number Diff line number Diff line
@@ -12,14 +12,10 @@ from urllib.parse import urljoin

from bs4 import BeautifulSoup, Tag

from tdoc_crawler.crawlers.constants import (DATE_PATTERN,
                                             MEETING_CODE_REGISTRY,
                                             MEETINGS_BASE_URL,
                                             PORTAL_BASE_URL)
from tdoc_crawler.crawlers.constants import DATE_PATTERN, MEETING_CODE_REGISTRY, MEETINGS_BASE_URL, PORTAL_BASE_URL
from tdoc_crawler.database import TDocDatabase
from tdoc_crawler.http_client import create_cached_session
from tdoc_crawler.models import (CrawlLimits, MeetingCrawlConfig,
                                 MeetingMetadata, WorkingGroup)
from tdoc_crawler.models import CrawlLimits, MeetingCrawlConfig, MeetingMetadata, WorkingGroup

logger = logging.getLogger(__name__)

@@ -100,6 +96,64 @@ class MeetingCrawlResult:
class MeetingCrawler:
    """Crawler fetching meeting metadata from the 3GPP portal."""

    def __init__(self, database: TDocDatabase) -> None:
        """Create a crawler bound to *database* for lookups and upserts."""
        self.database = database

    def crawl(self, config: MeetingCrawlConfig, progress_callback: Callable[[float, float], None] | None = None) -> MeetingCrawlResult:
        """Fetch meeting metadata for the configured working groups and store it.

        Args:
            config: Crawl settings (working groups, limits, HTTP cache,
                optional credentials, incremental flag, subgroup filter).
            progress_callback: Optional callback forwarded to the database
                bulk upsert so it can report per-row progress.

        Returns:
            A MeetingCrawlResult with processed/inserted/updated counts and
            any per-meeting-code error messages collected along the way.
        """
        failures: list[str] = []
        collected: list[MeetingMetadata] = []

        groups = self._limit_working_groups(config.working_groups, config.limits)
        # In incremental mode, known meeting ids are skipped during collection.
        known_ids: set[int] = (
            self.database.get_existing_meeting_ids(groups) if config.incremental else set()
        )

        session = create_cached_session(
            cache_dir=config.cache_dir,
            ttl=config.http_cache.ttl,
            refresh_ttl_on_access=config.http_cache.refresh_ttl_on_access,
            max_retries=config.max_retries,
        )
        session.headers["User-Agent"] = "tdoc-crawler/0.0.1"
        if config.credentials is not None:
            session.auth = (config.credentials.username, config.credentials.password)

        try:
            for group in groups:
                for code, subgroup in MEETING_CODE_REGISTRY.get(group.value, []):
                    # Honour the optional subgroup filter.
                    if config.subgroups and subgroup not in config.subgroups:
                        continue
                    page_url = MEETINGS_BASE_URL.format(code=code)
                    try:
                        response = session.get(page_url, timeout=config.timeout)
                        response.raise_for_status()
                    except Exception as exc:
                        # A failed page is recorded and skipped, never fatal.
                        message = f"Meeting crawl failed for {code}: {exc}"
                        logger.warning(message)
                        failures.append(message)
                        continue
                    for meeting in self._parse_meeting_page(response.text, group, subgroup):
                        if config.incremental and meeting.meeting_id in known_ids:
                            continue
                        collected.append(meeting)
        finally:
            # Release the HTTP session even when an unexpected error escapes.
            session.close()

        filtered = self._apply_limits(collected, config.limits)
        inserted = updated = 0
        if filtered:
            # Pass progress callback to bulk_upsert_meetings to update after each DB operation
            inserted, updated = self.database.bulk_upsert_meetings(filtered, progress_callback=progress_callback)

        return MeetingCrawlResult(
            processed=len(filtered),
            inserted=inserted,
            updated=updated,
            errors=failures,
        )

    def _limit_working_groups(self, working_groups: list[WorkingGroup], limits: CrawlLimits) -> list[WorkingGroup]:
        if limits.limit_wgs is None or limits.limit_wgs == 0:
            return working_groups
@@ -264,64 +318,6 @@ class MeetingCrawler:
        allowed = set(sequence[:limit]) if limit > 0 else set(sequence[limit:])
        return [meeting for meeting in meetings if meeting.meeting_id in allowed]

    def __init__(self, database: TDocDatabase) -> None:
        """Create a crawler bound to *database* for lookups and upserts."""
        self.database = database

    def crawl(self, config: MeetingCrawlConfig, progress_callback: Callable[[float, float], None] | None = None) -> MeetingCrawlResult:
        """Fetch meeting metadata for the configured working groups and store it.

        Args:
            config: Crawl settings (working groups, limits, HTTP cache,
                optional credentials, incremental flag, subgroup filter).
            progress_callback: Optional callback forwarded to the database
                bulk upsert so it can report per-row progress.

        Returns:
            A MeetingCrawlResult with processed/inserted/updated counts and
            any per-meeting-code error messages collected along the way.
        """
        errors: list[str] = []
        meetings: list[MeetingMetadata] = []

        working_groups = self._limit_working_groups(config.working_groups, config.limits)
        existing_ids: set[int] = set()
        if config.incremental:
            # Incremental mode: collect known ids so they can be skipped below.
            existing_ids = self.database.get_existing_meeting_ids(working_groups)
        session = create_cached_session(
            cache_dir=config.cache_dir,
            ttl=config.http_cache.ttl,
            refresh_ttl_on_access=config.http_cache.refresh_ttl_on_access,
            max_retries=config.max_retries,
        )
        session.headers["User-Agent"] = "tdoc-crawler/0.0.1"
        if config.credentials is not None:
            session.auth = (config.credentials.username, config.credentials.password)

        try:
            for working_group in working_groups:
                for code, subgroup in MEETING_CODE_REGISTRY.get(working_group.value, []):
                    # Skip subgroup if subgroups filter is set and this subgroup is not in the list
                    if config.subgroups and subgroup not in config.subgroups:
                        continue
                    url = MEETINGS_BASE_URL.format(code=code)
                    try:
                        response = session.get(url, timeout=config.timeout)
                        response.raise_for_status()
                    except Exception as exc:
                        # A failed page is recorded and skipped, never fatal.
                        message = f"Meeting crawl failed for {code}: {exc}"
                        logger.warning(message)
                        errors.append(message)
                        continue
                    parsed_meetings = self._parse_meeting_page(response.text, working_group, subgroup)
                    for meeting in parsed_meetings:
                        if config.incremental and meeting.meeting_id in existing_ids:
                            continue
                        meetings.append(meeting)
        finally:
            # Release the HTTP session even when an unexpected error escapes.
            session.close()

        filtered = self._apply_limits(meetings, config.limits)
        inserted = 0
        updated = 0
        if filtered:
            # Pass progress callback to bulk_upsert_meetings to update after each DB operation
            inserted, updated = self.database.bulk_upsert_meetings(filtered, progress_callback=progress_callback)

        return MeetingCrawlResult(
            processed=len(filtered),
            inserted=inserted,
            updated=updated,
            errors=errors,
        )


__all__ = [
    "MEETING_CODE_REGISTRY",
+50 −102
Original line number Diff line number Diff line
@@ -11,71 +11,23 @@ from pathlib import Path
from pydantic_sqlite import DataBase

from tdoc_crawler.database.errors import DatabaseError
from tdoc_crawler.models import (CODE_INDEX, WORKING_GROUP_RECORDS,
                                 CrawlLogEntry, MeetingMetadata,
                                 MeetingQueryConfig, QueryConfig, TDocMetadata,
                                 WorkingGroup, utc_now)
from tdoc_crawler.models import (
    CODE_INDEX,
    WORKING_GROUP_RECORDS,
    CrawlLogEntry,
    MeetingMetadata,
    MeetingQueryConfig,
    QueryConfig,
    TDocMetadata,
    WorkingGroup,
    utc_now,
)
from tdoc_crawler.models.subworking_groups import SUBWORKING_GROUP_RECORDS


class TDocDatabase:
    """High-level facade for all database operations."""

    @staticmethod
    def _tdoc_changed(current: TDocMetadata, candidate: TDocMetadata) -> bool:
        """Return True when any field other than ``date_updated`` differs."""
        return any(
            getattr(current, name) != getattr(candidate, name)
            for name in TDocMetadata.model_fields
            if name != "date_updated"
        )

    @staticmethod
    def _meeting_changed(current: MeetingMetadata, candidate: MeetingMetadata) -> bool:
        """Return True when any tracked field differs between the records.

        Bookkeeping fields (``updated_at``, ``last_synced``, ``tdoc_count``)
        are ignored so they never trigger an update by themselves.
        """
        ignored = {"updated_at", "last_synced", "tdoc_count"}
        return any(
            getattr(current, name) != getattr(candidate, name)
            for name in MeetingMetadata.model_fields
            if name not in ignored
        )
                        except ValueError, AttributeError:

    # ------------------------------------------------------------------
    # Normalisation helpers
    # ------------------------------------------------------------------
    def _prepare_tdoc(self, metadata: TDocMetadata) -> TDocMetadata:
        """Return *metadata* with ``date_retrieved`` defaulted to now (UTC).

        The original instance is returned unchanged when the field is set.
        """
        if metadata.date_retrieved is None:
            return metadata.model_copy(update={"date_retrieved": utc_now()})
        return metadata

    def _prepare_meeting(self, metadata: MeetingMetadata) -> MeetingMetadata:
        """Backfill ``working_group`` from ``tbid`` when it is missing."""
        if metadata.working_group is not None or not metadata.tbid:
            return metadata
        # First WorkingGroup member whose tbid matches, if any.
        match = next((wg for wg in WorkingGroup if wg.tbid == metadata.tbid), None)
        if match is None:
            return metadata
        return metadata.model_copy(update={"working_group": match})

    # ------------------------------------------------------------------
    # Data clearing methods
    # ------------------------------------------------------------------
    def _table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        Args:
            table_name: Name of the table to check

        Returns:
            True if table exists, False otherwise
        """
        # NOTE(review): reaches into pydantic_sqlite's private ``_db`` handle;
        # no public API for raw queries appears to be available.
        query = "SELECT name FROM sqlite_master WHERE type='table' AND name=?"
        row = self.connection._db.execute(query, (table_name,)).fetchone()
        return row is not None

    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        self._database: DataBase | None = None
@@ -92,7 +44,7 @@ class TDocDatabase:
            raise DatabaseError("database-initialization-failed", detail=str(exc)) from exc
        return self

    def __exit__(self, exc_type, exc, exc_tb) -> None:
    def __exit__(self, exc_type: type[BaseException] | None, exc: BaseException | None, exc_tb: object | None) -> None:
        """Context-manager exit: discard the reference to the open database."""
        self._database = None

    # ------------------------------------------------------------------
@@ -101,7 +53,6 @@ class TDocDatabase:
    @property
    def connection(self) -> DataBase:
        """Expose the underlying DataBase instance (read-only)."""
        database = self._database
        if database is None:
            # Raised when used outside the ``with TDocDatabase(...)`` context.
            raise DatabaseError.connection_not_open()
        return database
@@ -536,9 +487,44 @@ class TDocDatabase:
        meetings_count = self.clear_meetings()
        return (tdocs_count, meetings_count)

    # ------------------------------------------------------------------
    # Normalisation helpers
    # ------------------------------------------------------------------
    def _prepare_tdoc(self, metadata: TDocMetadata) -> TDocMetadata:
        """Return *metadata* with ``date_retrieved`` defaulted to now (UTC).

        The original instance is returned unchanged when no update is needed.
        """
        updates: dict[str, object] = {}
        if metadata.date_retrieved is None:
            updates["date_retrieved"] = utc_now()
        if updates:
            # model_copy returns a new instance; the input is never mutated.
            return metadata.model_copy(update=updates)
        return metadata

    def _prepare_meeting(self, metadata: MeetingMetadata) -> MeetingMetadata:
        """Backfill ``working_group`` from ``tbid`` when it is missing.

        Scans the WorkingGroup enum for a member with a matching ``tbid``;
        returns the input unchanged when nothing matches or nothing is needed.
        """
        if metadata.working_group is None and metadata.tbid:
            for working_group in WorkingGroup:
                if working_group.tbid == metadata.tbid:
                    return metadata.model_copy(update={"working_group": working_group})
        return metadata

    # ------------------------------------------------------------------
    # Data clearing methods
    # ------------------------------------------------------------------
    def _table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        Args:
            table_name: Name of the table to check

        Returns:
            True if table exists, False otherwise
        """
        # NOTE(review): reaches into pydantic_sqlite's private ``_db`` handle;
        # no public API for raw queries appears to be available.
        cursor = self.connection._db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,),
        )
        return cursor.fetchone() is not None

    def _ensure_reference_data(self) -> None:
        """Populate reference tables for working and subworking groups."""

        database = self.connection
        for record in WORKING_GROUP_RECORDS:
            database.add("working_groups", record, pk="tbid")
@@ -547,8 +533,6 @@ class TDocDatabase:

    def _table_rows(self, table: str) -> list:
        """Fetch all rows from a table using raw SQL to avoid pydantic_sqlite registry issues."""
        from datetime import datetime

        # Map table names to their model classes and safe names
        model_map = {
            "tdocs": TDocMetadata,
@@ -578,7 +562,7 @@ class TDocDatabase:
                        try:
                            if "T" in value and value.endswith(("Z", "+00:00")):
                                row_dict[key] = datetime.fromisoformat(value)
                        except ValueError, AttributeError:
                        except (ValueError, AttributeError):
                            pass
                result.append(model_class(**row_dict))
            return result
@@ -590,7 +574,7 @@ class TDocDatabase:

    def _get_meeting(self, meeting_id: int) -> MeetingMetadata | None:
        """Fetch a single meeting row by id, or None when absent.

        Returns:
            The stored MeetingMetadata, or None if no row matches.
        """
        try:
            # The table keys rows by a string primary key, so the integer id
            # must be converted before lookup (passing the raw int fails).
            return self.connection.model_from_table("meetings", str(meeting_id))
        except KeyError:
            # model_from_table signals a missing row with KeyError.
            return None

@@ -600,42 +584,6 @@ class TDocDatabase:
        except KeyError:
            return None

    # ------------------------------------------------------------------
    # Normalisation helpers
    # ------------------------------------------------------------------
    def _prepare_tdoc(self, metadata: TDocMetadata) -> TDocMetadata:
        """Return *metadata* with ``date_retrieved`` defaulted to now (UTC).

        The original instance is returned unchanged when no update is needed.
        """
        updates: dict[str, object] = {}
        if metadata.date_retrieved is None:
            updates["date_retrieved"] = utc_now()
        if updates:
            # model_copy returns a new instance; the input is never mutated.
            return metadata.model_copy(update=updates)
        return metadata

    def _prepare_meeting(self, metadata: MeetingMetadata) -> MeetingMetadata:
        """Backfill ``working_group`` from ``tbid`` when it is missing.

        Scans the WorkingGroup enum for a member with a matching ``tbid``;
        returns the input unchanged when nothing matches or nothing is needed.
        """
        if metadata.working_group is None and metadata.tbid:
            for working_group in WorkingGroup:
                if working_group.tbid == metadata.tbid:
                    return metadata.model_copy(update={"working_group": working_group})
        return metadata

    # ------------------------------------------------------------------
    # Data clearing methods
    # ------------------------------------------------------------------
    def _table_exists(self, table_name: str) -> bool:
        """Check if a table exists in the database.

        Args:
            table_name: Name of the table to check

        Returns:
            True if table exists, False otherwise
        """
        # NOTE(review): reaches into pydantic_sqlite's private ``_db`` handle;
        # no public API for raw queries appears to be available.
        cursor = self.connection._db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table_name,),
        )
        return cursor.fetchone() is not None

    @staticmethod
    def _tdoc_changed(current: TDocMetadata, candidate: TDocMetadata) -> bool:
        for field in TDocMetadata.model_fields:
+51 −37
Original line number Diff line number Diff line
@@ -3,26 +3,44 @@
from __future__ import annotations

# Re-export all public symbols
from .base import (DEFAULT_CACHE_DIR, BaseConfigModel, OutputFormat,
                   PortalCredentials, SortOrder, utc_now)
from .crawl_limits import CrawlLimits
from .crawl_log import CrawlLogEntry
from .meetings import MeetingCrawlConfig, MeetingMetadata, MeetingQueryConfig
from .subworking_groups import (CODE_INDEX, SUBTB_INDEX,
from .base import (  # noqa: F401
    DEFAULT_CACHE_DIR,
    BaseConfigModel,
    OutputFormat,
    PortalCredentials,
    SortOrder,
    utc_now,
)
from .crawl_limits import CrawlLimits  # noqa: F401
from .crawl_log import CrawlLogEntry  # noqa: F401
from .meetings import (  # noqa: F401
    MeetingCrawlConfig,
    MeetingMetadata,
    MeetingQueryConfig,
)
from .subworking_groups import (  # noqa: F401
    CODE_INDEX,
    SUBTB_INDEX,
    SUBWORKING_GROUP_RECORDS,
                                SubWorkingGroupRecord)
from .tdocs import CrawlConfig, QueryConfig, TDocCrawlConfig, TDocMetadata
from .working_groups import (WORKING_GROUP_RECORDS, WorkingGroup,
                             WorkingGroupRecord)
    SubWorkingGroupRecord,
)
from .tdocs import CrawlConfig, QueryConfig, TDocCrawlConfig, TDocMetadata  # noqa: F401
from .working_groups import (  # noqa: F401
    WORKING_GROUP_RECORDS,
    WorkingGroup,
    WorkingGroupRecord,
)

__all__ = sorted(
    [
        "BaseConfigModel",
__all__ = [
    "CODE_INDEX",
    "DEFAULT_CACHE_DIR",
    "SUBTB_INDEX",
    "SUBWORKING_GROUP_RECORDS",
    "WORKING_GROUP_RECORDS",
    "BaseConfigModel",
    "CrawlConfig",
    "CrawlLimits",
    "CrawlLogEntry",
        "DEFAULT_CACHE_DIR",
    "MeetingCrawlConfig",
    "MeetingMetadata",
    "MeetingQueryConfig",
@@ -30,14 +48,10 @@ __all__ = sorted(
    "PortalCredentials",
    "QueryConfig",
    "SortOrder",
        "SUBTB_INDEX",
        "SUBWORKING_GROUP_RECORDS",
    "SubWorkingGroupRecord",
    "TDocCrawlConfig",
    "TDocMetadata",
        "WORKING_GROUP_RECORDS",
    "WorkingGroup",
    "WorkingGroupRecord",
    "utc_now",
]
)
+3 −9
Original line number Diff line number Diff line
@@ -2,13 +2,14 @@

from __future__ import annotations

from datetime import datetime
from datetime import UTC, date, datetime
from decimal import Decimal
from pathlib import Path

import pytest

from tdoc_crawler.database import TDocDatabase
from tdoc_crawler.models import TDocMetadata
from tdoc_crawler.models import MeetingMetadata, TDocMetadata, WorkingGroup


@pytest.fixture
@@ -46,9 +47,6 @@ def sample_tdocs() -> list[TDocMetadata]:
    Returns:
        List of sample TDocMetadata instances
    """
    from datetime import UTC
    from decimal import Decimal

    return [
        TDocMetadata(
            tdoc_id="R1-2301234",
@@ -188,10 +186,6 @@ def insert_sample_meetings(database: TDocDatabase, meetings: list[dict]) -> None
            - location: Meeting location
            - files_url: URL to meeting files
    """
    from datetime import date

    from tdoc_crawler.models import MeetingMetadata, WorkingGroup

    # Map tbid to WorkingGroup enum
    tbid_to_wg = {
        373: WorkingGroup.RAN,  # RAN
+2 −7
Original line number Diff line number Diff line
@@ -5,11 +5,12 @@ from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch

from conftest import insert_sample_meetings
from typer.testing import CliRunner

from tdoc_crawler.cli import app
from tdoc_crawler.database import TDocDatabase
from tdoc_crawler.models import TDocMetadata
from tdoc_crawler.models import TDocMetadata, WorkingGroup

runner = CliRunner()

@@ -300,8 +301,6 @@ class TestQueryMeetingsCommand:
        test_cache_dir: Path,
    ) -> None:
        """Test query-meetings with working group alias (SP -> SA)."""
        from tdoc_crawler.models import WorkingGroup

        mock_db = MagicMock(spec=TDocDatabase)
        mock_db_class.return_value.__enter__.return_value = mock_db
        mock_db.query_meetings.return_value = []
@@ -323,8 +322,6 @@ class TestQueryMeetingsCommand:
        test_cache_dir: Path,
    ) -> None:
        """Test query-meetings with combined working group and subgroup filters."""
        from tdoc_crawler.models import WorkingGroup

        mock_db = MagicMock(spec=TDocDatabase)
        mock_db_class.return_value.__enter__.return_value = mock_db
        mock_db.query_meetings.return_value = []
@@ -351,8 +348,6 @@ class TestStatsCommand:
        sample_meetings: list[dict],
    ) -> None:
        """Test stats command execution."""
        from conftest import insert_sample_meetings

        with TDocDatabase(test_db_path) as db:
            insert_sample_meetings(db, sample_meetings)
            for tdoc in sample_tdocs:
Loading