Commit fee3bbf2 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(http-client): add connection pooling configuration for HTTP sessions

* Introduced PoolConfig dataclass for HTTP connection pooling settings.
* Updated create_cached_session to accept pool_config for pooling behavior.
* Enhanced download_to_file to utilize connection timeout from pool_config.
* Added tests to verify pooling behavior in cached and non-cached sessions.
parent dec472f7
Loading
Loading
Loading
Loading
+32 −14
Original line number Diff line number Diff line
@@ -7,7 +7,6 @@ import importlib.metadata
import json
import logging
from datetime import UTC, datetime
from functools import cache
from pathlib import Path
from typing import Any, NoReturn

@@ -54,7 +53,7 @@ from tdoc_crawler.cli.args import (
    WorkspaceProcessNewOnlyOption,
    WorkspaceReleaseOption,
)
from tdoc_crawler.config import CacheManager, resolve_cache_manager
from tdoc_crawler.config import CacheManager
from tdoc_crawler.database import SpecDatabase, TDocDatabase
from tdoc_crawler.logging import get_console
from tdoc_crawler.models.base import OutputFormat, SortOrder
@@ -65,8 +64,17 @@ from tdoc_crawler.utils.normalization import resolve_release_to_full_version
ai_app = typer.Typer(help="3GPP AI - Document Processing and RAG")
console = get_console()
_logger = logging.getLogger(__name__)

# TODO: This is a really hacky way to handle optional AI dependencies. Refactor to a proper plugin system or separate module if it grows.
# TODO: The whole module suffers from circular imports and dynamic loading. Refactor to a cleaner structure if more commands are added.
# TODO: The whole module also has redundant global variables and repeated/dynamic imports.
# TODO: CacheManager handling is messy and inconsistent; a different cache directory is not even considered!

_ai_loaded: list[bool] = [False]
_ai_config: Any = None

# Register CacheManager once at module load time (CLI is the entry point)
_cache_manager: CacheManager | None = None
checkout_spec_to_workspace: Any = None
checkout_tdoc_to_workspace: Any = None
convert_document: Any = None
@@ -155,6 +163,11 @@ def _load_ai() -> None:

    _ai_loaded[0] = True

    # Register CacheManager once when AI module is loaded
    global _cache_manager
    if _cache_manager is None:
        _cache_manager = CacheManager().register()


@ai_app.callback()
def _check_ai_installed() -> None:
@@ -162,15 +175,19 @@ def _check_ai_installed() -> None:
    _load_ai()


@cache
def _get_cache_manager() -> CacheManager:
    """Get or create the default cache manager (avoids repeated register() calls)."""
    # Check if already registered
    try:
        return resolve_cache_manager("default")
    except ValueError:
        # Not registered yet, create and register
    """Get the registered cache manager.

    Returns:
        The registered CacheManager instance.

    Raises:
        ValueError: If CacheManager was not registered (should not happen in CLI).
    """
    if _cache_manager is None:
        # Fallback: register now (should have happened in _load_ai)
        return CacheManager().register()
    return _cache_manager


def resolve_workspace(workspace: str | None) -> str:
@@ -230,7 +247,10 @@ def ai_convert(
        elif json_output:
            typer.echo(json.dumps({"markdown": markdown_content}))
        else:
            typer.echo(markdown_content)
            # Write raw UTF-8 to stdout (avoid Rich markup/encoding issues)
            import sys

            sys.stdout.buffer.write(markdown_content.encode("utf-8"))
    except Exception as exc:
        if json_output:
            typer.echo(json.dumps({"error": str(exc)}), err=True)
@@ -688,7 +708,7 @@ def workspace_process(
    import asyncio

    workspace = resolve_workspace(workspace)
    manager = _get_cache_manager()
    _get_cache_manager()

    # Get workspace members
    members = list_workspace_members(workspace, include_inactive=False)
@@ -713,8 +733,6 @@ def workspace_process(
        processor = TDocProcessor(config)
        results = []

        import asyncio

        from threegpp_ai.lightrag.config import LightRAGConfig
        from threegpp_ai.lightrag.metadata import RAGMetadata
        from threegpp_ai.lightrag.processor import TDocProcessor
@@ -776,8 +794,8 @@ def workspace_process(
                metadata = None
                if member.source_item_id.startswith(("S", "R", "C", "T")):  # TDoc ID pattern
                    try:
                        from tdoc_crawler.database import TDocDatabase
                        from tdoc_crawler.config import resolve_cache_manager
                        from tdoc_crawler.database import TDocDatabase

                        manager = resolve_cache_manager()
                        with TDocDatabase(manager.db_file) as db:
+3 −2
Original line number Diff line number Diff line
@@ -3,10 +3,11 @@
Re-exports from session module for backward-compatible imports:
    from tdoc_crawler.http_client import create_cached_session
    from tdoc_crawler.http_client import download_to_file
    from tdoc_crawler.http_client import PoolConfig
"""

from __future__ import annotations

from tdoc_crawler.http_client.session import create_cached_session, download_to_file
from tdoc_crawler.http_client.session import PoolConfig, create_cached_session, download_to_file

__all__ = ["create_cached_session", "download_to_file"]
__all__ = ["PoolConfig", "create_cached_session", "download_to_file"]
+67 −8
Original line number Diff line number Diff line
@@ -3,12 +3,14 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import cast

import requests
from hishel import SyncBaseStorage, SyncSqliteStorage
from hishel.requests import CacheAdapter
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from tdoc_crawler.config import resolve_cache_manager
@@ -19,6 +21,25 @@ from tdoc_crawler.models import HttpCacheConfig
logger = get_logger(__name__)


@dataclass
class PoolConfig:
    """HTTP connection pool configuration.

    Passed to ``create_cached_session``/``download_to_file`` to size the
    urllib3 connection pool and configure retry behavior on the mounted
    adapter (cache adapter when caching is enabled, plain HTTPAdapter
    otherwise).

    Attributes:
        max_connections: Maximum number of connection pools to cache
            (maps to the adapter's ``pool_connections``).
        max_per_host: Maximum number of connections per host
            (maps to the adapter's ``pool_maxsize``).
        connection_timeout: Connection timeout in seconds; also used as the
            request timeout in ``download_to_file`` when a pool config is given.
        enable_retry: Whether to enable retry logic; when False the retry
            attempt count is ignored.
        retry_attempts: Number of retry attempts for failed requests.
    """

    max_connections: int = 10
    max_per_host: int = 5
    connection_timeout: float = 30.0
    enable_retry: bool = True
    retry_attempts: int = 3


def download_to_file(
    url: str,
    destination: Path,
@@ -27,6 +48,7 @@ def download_to_file(
    http_cache: HttpCacheConfig | None = None,
    cache_manager_name: str | None = None,
    http_cache_enabled: bool | None = None,
    pool_config: PoolConfig | None = None,
) -> requests.Session | None:
    """Download a file from URL to destination path.

@@ -38,6 +60,7 @@ def download_to_file(
        http_cache: Optional HTTP cache configuration
        cache_manager_name: Optional cache manager name to determine cache directory when creating a temporary session.
        http_cache_enabled: Whether to enable HTTP caching. If None, defaults to True.
        pool_config: Optional connection pool configuration.

    Raises:
        ValueError: If URL scheme is not supported
@@ -52,13 +75,19 @@ def download_to_file(
    # Use provided session or create a new one (might be used for multiple downloads, so we don't want to create a new session for each)
    temp_session: requests.Session | None = None
    if session is None:
        temp_session = create_cached_session(http_cache=http_cache, cache_manager_name=cache_manager_name, http_cache_enabled=http_cache_enabled)
        temp_session = create_cached_session(
            http_cache=http_cache,
            cache_manager_name=cache_manager_name,
            http_cache_enabled=http_cache_enabled,
            pool_config=pool_config,
        )
        active_session = temp_session
    else:
        active_session = session

    try:
        response = active_session.get(url, timeout=60, stream=True)
        timeout = pool_config.connection_timeout if pool_config is not None else 60
        response = active_session.get(url, timeout=timeout, stream=True)
        response.raise_for_status()
        with destination.open("wb") as target:
            for chunk in response.iter_content(chunk_size=8192):
@@ -79,6 +108,7 @@ def create_cached_session(
    http_cache: HttpCacheConfig | None = None,
    cache_manager_name: str | None = None,
    http_cache_enabled: bool | None = None,
    pool_config: PoolConfig | None = None,
) -> requests.Session:
    """Create a requests.Session with hishel caching enabled.

@@ -87,6 +117,9 @@ def create_cached_session(
        cache_manager_name: Optional cache manager name to determine cache configuration.
        http_cache_enabled: Whether to enable HTTP caching. If None, defaults to True.
                         Can be set via HTTP_CACHE_ENABLED environment variable.
        pool_config: Optional connection pool configuration.
                When provided, pool settings are applied to the active adapter
                (cache adapter when caching is enabled, HTTPAdapter otherwise).

    Returns:
        Configured requests.Session with caching enabled (unless disabled)
@@ -100,8 +133,24 @@ def create_cached_session(
        env_enabled = os.getenv("HTTP_CACHE_ENABLED", "").lower()
        http_cache_enabled = env_enabled not in ("false", "0", "no", "off", "f", "n")

    # If caching is disabled, return a plain session without caching
    # If caching is disabled, optionally configure a pooled HTTP adapter and return.
    if not http_cache_enabled:
        if pool_config is not None:
            retry_strategy = Retry(
                total=pool_config.retry_attempts if pool_config.enable_retry else 0,
                backoff_factor=1,
                status_forcelist=[429, 500, 502, 503, 504],
                allowed_methods=["HEAD", "GET", "OPTIONS"],
            )
            adapter = HTTPAdapter(
                pool_connections=pool_config.max_connections,
                pool_maxsize=pool_config.max_per_host,
                max_retries=retry_strategy,
            )
            session.mount("http://", adapter)
            session.mount("https://", adapter)
            logger.debug(f"Configured connection pool without caching: max_connections={pool_config.max_connections}, max_per_host={pool_config.max_per_host}")

        logger.debug("Creating plain HTTP session (caching disabled)")
        return session

@@ -122,22 +171,32 @@ def create_cached_session(
    )
    storage = cast(SyncBaseStorage, storage)

    # Configure retry strategy for the session
    # Configure retry strategy for the cache adapter.
    # If pool_config is set, reuse its retry settings so caching + pooling share one adapter.
    retry_attempts = pool_config.retry_attempts if pool_config and pool_config.enable_retry else http_cache.max_retries
    max_retries = Retry(
        total=http_cache.max_retries,
        total=retry_attempts,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["HEAD", "GET", "OPTIONS"],
    )

    # Create cache adapter
    cache_adapter = CacheAdapter(storage=storage, max_retries=max_retries)  # ty:ignore[invalid-argument-type]
    # Create a single cache adapter with pool settings so both behaviors are active.
    cache_adapter = CacheAdapter(
        pool_connections=pool_config.max_connections if pool_config is not None else 10,
        pool_maxsize=pool_config.max_per_host if pool_config is not None else 10,
        max_retries=max_retries,
        storage=storage,
    )

    # Mount the cache adapter
    session.mount("http://", cache_adapter)
    session.mount("https://", cache_adapter)

    if pool_config is not None:
        logger.debug(f"Configured cache + pool adapter: max_connections={pool_config.max_connections}, max_per_host={pool_config.max_per_host}")

    return session


__all__ = ["create_cached_session"]
__all__ = ["PoolConfig", "create_cached_session", "download_to_file"]
+70 −1
Original line number Diff line number Diff line
@@ -10,9 +10,10 @@ from unittest.mock import MagicMock, patch
import pytest
import requests
from hishel.requests import CacheAdapter
from requests.adapters import HTTPAdapter

from tdoc_crawler.config import DEFAULT_HTTP_CACHE_FILENAME, CacheManager, reset_cache_managers
from tdoc_crawler.http_client import create_cached_session
from tdoc_crawler.http_client import PoolConfig, create_cached_session, download_to_file
from tdoc_crawler.models.base import HttpCacheConfig


@@ -111,6 +112,50 @@ class TestCreateCachedSession:

        session.close()

    def test_cache_adapter_keeps_pool_settings(self, test_cache_dir: Path) -> None:
        """Regression: pool sizing must be active on the cache adapter itself."""
        CacheManager(root_path=test_cache_dir, name="cache_pool").register()

        session = create_cached_session(
            cache_manager_name="cache_pool",
            pool_config=PoolConfig(max_connections=7, max_per_host=11, retry_attempts=4),
        )

        # Both schemes must be served by the same kind of pooled cache adapter.
        for scheme in ("http://", "https://"):
            adapter = session.adapters[scheme]
            assert isinstance(adapter, CacheAdapter)
            assert adapter._pool_connections == 7
            assert adapter._pool_maxsize == 11

        session.close()

    def test_pool_only_session_uses_http_adapter(self) -> None:
        """When caching is disabled, pooling should still be configured."""
        session = create_cached_session(
            http_cache_enabled=False,
            pool_config=PoolConfig(max_connections=4, max_per_host=9, retry_attempts=2),
        )

        for scheme in ("http://", "https://"):
            adapter = session.adapters[scheme]
            # Expect a plain pooled adapter, not the caching subclass.
            assert isinstance(adapter, HTTPAdapter)
            assert not isinstance(adapter, CacheAdapter)
            assert adapter._pool_connections == 4
            assert adapter._pool_maxsize == 9

        session.close()

    def test_default_parameters(self, test_cache_dir: Path) -> None:
        """Test that default parameters work correctly."""
        CacheManager(root_path=test_cache_dir, name="default").register()
@@ -238,6 +283,30 @@ class TestCachingBehavior:
        session.close()


class TestDownloadToFilePooling:
    """Tests for pooled timeout usage in download helper."""

    def test_uses_pool_timeout_for_download(self, tmp_path: Path) -> None:
        """Regression: download_to_file should use pool connection timeout."""
        target = tmp_path / "download.bin"

        # Fake a streamed two-chunk response on a mocked session.
        response = MagicMock()
        response.raise_for_status.return_value = None
        response.iter_content.return_value = [b"abc", b"def"]
        mocked_session = MagicMock(spec=requests.Session)
        mocked_session.get.return_value = response

        download_to_file(
            "https://example.com/test.bin",
            target,
            session=mocked_session,
            pool_config=PoolConfig(connection_timeout=7.5),
        )

        # The pool's connection_timeout must flow through to the GET call.
        mocked_session.get.assert_called_once_with("https://example.com/test.bin", timeout=7.5, stream=True)
        assert target.read_bytes() == b"abcdef"


class TestResolveHttpCacheConfig:
    """Tests for resolve_http_cache_config helper function."""

+1 −1

File changed.

Contains only whitespace changes.