Commit eaebc5ba authored by Jan Reimes
Browse files

♻️ refactor(http_client): simplify niquests-hishel session adapters

parent 9e3dca91
Loading
Loading
Loading
Loading
+0 −4
Original line number Diff line number Diff line
@@ -10,8 +10,6 @@ from __future__ import annotations

from tdoc_crawler.http_client.session import (
    PoolConfig,
    SSLContextCacheAdapter,
    SSLContextHTTPAdapter,
    create_cached_session,
    download_to_file,
    resolve_ssl_verify,
@@ -19,8 +17,6 @@ from tdoc_crawler.http_client.session import (

__all__ = [
    "PoolConfig",
    "SSLContextCacheAdapter",
    "SSLContextHTTPAdapter",
    "create_cached_session",
    "download_to_file",
    "resolve_ssl_verify",
+72 −58
Original line number Diff line number Diff line
@@ -2,16 +2,17 @@

from __future__ import annotations

import ssl
from dataclasses import dataclass
from pathlib import Path
from typing import cast

import niquests as requests
from hishel._core._headers import Headers
from hishel._core.models import Request, Response
from hishel import SyncBaseStorage, SyncSqliteStorage
from hishel.requests import CacheAdapter
from hishel.requests import CacheAdapter, extract_metadata_from_headers, snake_to_header
from niquests.adapters import HTTPAdapter
from truststore import SSLContext as TruststoreSSLContext
from urllib3.response import HTTPResponse
from urllib3.util.retry import Retry

from tdoc_crawler.config.settings import HttpConfig, PathConfig
@@ -21,40 +22,71 @@ from tdoc_crawler.logging import get_logger
logger = get_logger(__name__)


class SSLContextHTTPAdapter(HTTPAdapter):
    """HTTP adapter that can enforce a specific SSL context."""
def _niquests_to_internal_request(model: requests.models.PreparedRequest) -> Request:
    """Convert niquests prepared request into hishel internal request model."""
    body: bytes
    if isinstance(model.body, str):
        body = model.body.encode("utf-8")
    elif isinstance(model.body, bytes):
        body = model.body
    else:
        body = b""

    def __init__(self, *args: object, ssl_context: ssl.SSLContext | None = None, **kwargs: object) -> None:
        self._ssl_context = ssl_context
        super().__init__(*args, **kwargs)
    if model.method is None:
        raise ValueError("prepared request method must not be None")

    def init_poolmanager(self, connections: int, maxsize: int, block: bool = False, **pool_kwargs: object) -> None:
        if self._ssl_context is not None:
            pool_kwargs["ssl_context"] = self._ssl_context
        super().init_poolmanager(connections=connections, maxsize=maxsize, block=block, **pool_kwargs)
    return Request(
        method=model.method,
        url=str(model.url),
        headers=Headers(model.headers),
        stream=iter([body]),
        metadata=extract_metadata_from_headers(model.headers),
    )

    def proxy_manager_for(self, proxy: str, **proxy_kwargs: object) -> object:
        if self._ssl_context is not None:
            proxy_kwargs["ssl_context"] = self._ssl_context
        return super().proxy_manager_for(proxy, **proxy_kwargs)

def _internal_to_niquests_response(model: Response) -> requests.models.Response:
    """Convert hishel internal response model into niquests response model."""
    response = requests.models.Response()

class SSLContextCacheAdapter(CacheAdapter):
    """Cache adapter variant that can enforce a specific SSL context."""
    body = b"".join(model.stream) if model.stream is not None else b""
    metadata_headers = {snake_to_header(key): str(value) for key, value in model.metadata.items()}

    def __init__(self, *args: object, ssl_context: ssl.SSLContext | None = None, **kwargs: object) -> None:
        self._ssl_context = ssl_context
        super().__init__(*args, **kwargs)
    response.raw = HTTPResponse(
        body=body,
        headers={**model.headers, **metadata_headers},
        status=model.status_code,
        preload_content=False,
        decode_content=False,
    )
    response.status_code = model.status_code
    response.headers.update(model.headers)
    response.headers.update(metadata_headers)
    response._content = body
    response._content_consumed = True
    response.url = ""

    def init_poolmanager(self, connections: int, maxsize: int, block: bool = False, **pool_kwargs: object) -> None:
        if self._ssl_context is not None:
            pool_kwargs["ssl_context"] = self._ssl_context
        super().init_poolmanager(connections=connections, maxsize=maxsize, block=block, **pool_kwargs)
    return response

    def proxy_manager_for(self, proxy: str, **proxy_kwargs: object) -> object:
        if self._ssl_context is not None:
            proxy_kwargs["ssl_context"] = self._ssl_context
        return super().proxy_manager_for(proxy, **proxy_kwargs)

class _NiquetsCacheAdapter(CacheAdapter):
    """hishel ``CacheAdapter`` whose ``send`` works with niquests objects.

    niquests.PreparedRequest is not a subclass of requests.PreparedRequest,
    so the isinstance checks inside hishel would reject it. ``send`` is
    overridden here to translate to and from hishel's internal models.
    """

    def send(
        self,
        request: requests.models.PreparedRequest,
        **_kwargs: object,
    ) -> requests.models.Response:
        """Route a niquests prepared request through hishel's cache proxy.

        Args:
            request: The niquests prepared request to dispatch.
            **_kwargs: Accepted for signature compatibility; not used by this
                method.
                NOTE(review): per-call options passed here are not forwarded
                anywhere in this method — confirm that is intended.

        Returns:
            A niquests response rebuilt from hishel's internal response, with
            ``request`` and ``connection`` attached.
        """
        hishel_response = self._cache_proxy.handle_request(
            _niquests_to_internal_request(request),
        )
        result = _internal_to_niquests_response(hishel_response)
        # Link the rebuilt response back to its originating request and to
        # this adapter.
        result.request = request
        result.connection = self  # type: ignore[assignment]
        return result


@dataclass
@@ -107,20 +139,6 @@ def resolve_ssl_verify(
    return resolved


def _resolve_ssl_context(verify: bool | str) -> ssl.SSLContext | None:
    """Return the SSL context to inject at adapter level, if any.

    Args:
        verify: Resolved SSL verification behavior (``True``, ``False``, or a
            string value).

    Returns:
        A ``TruststoreSSLContext`` created with ``ssl.PROTOCOL_TLS_CLIENT``
        when ``verify`` is exactly ``True``; ``None`` for every other value.
    """
    # Identity check on purpose: only the literal True triggers an override.
    if verify is not True:
        return None
    return TruststoreSSLContext(ssl.PROTOCOL_TLS_CLIENT)


def download_to_file(
    url: str,
    destination: Path,
@@ -216,7 +234,6 @@ def create_cached_session(

    # Use http_config for SSL if no explicit verify provided
    verify_mode = http_config.verify_ssl if verify is None and http_config is not None else resolve_ssl_verify(verify, http_config)
    ssl_context = _resolve_ssl_context(verify_mode)
    session.verify = verify_mode

    # Resolve cache enabled flag: explicit param → http_config → default True
@@ -224,24 +241,23 @@ def create_cached_session(
        http_cache_enabled = http_config.cache_enabled if http_config is not None else True

    # If caching is disabled, optionally configure a pooled HTTP adapter and return.
    # niquests handles SSL natively, so no custom SSL context injection is needed.
    if not http_cache_enabled:
        if pool_config is not None or ssl_context is not None:
            retry_attempts = pool_config.retry_attempts if pool_config and pool_config.enable_retry else 0
        if pool_config is not None:
            retry_attempts = pool_config.retry_attempts if pool_config.enable_retry else 0
            retry_strategy = Retry(
                total=retry_attempts,
                backoff_factor=1,
                status_forcelist=[429, 500, 502, 503, 504],
                allowed_methods=["HEAD", "GET", "OPTIONS"],
            )
            adapter = SSLContextHTTPAdapter(
                pool_connections=pool_config.max_connections if pool_config is not None else 10,
                pool_maxsize=pool_config.max_per_host if pool_config is not None else 10,
            adapter = HTTPAdapter(
                pool_connections=pool_config.max_connections,
                pool_maxsize=pool_config.max_per_host,
                max_retries=retry_strategy,
                ssl_context=ssl_context,
            )
            session.mount("http://", adapter)
            session.mount("https://", adapter)
            if pool_config is not None:
            logger.debug(
                f"Configured connection pool without caching: max_connections={pool_config.max_connections}, max_per_host={pool_config.max_per_host}"
            )
@@ -279,12 +295,12 @@ def create_cached_session(
    )

    # Create a single cache adapter with pool settings so both behaviors are active.
    cache_adapter = SSLContextCacheAdapter(
    # _NiquetsCacheAdapter bridges niquests<->hishel type incompatibility in send().
    cache_adapter = _NiquetsCacheAdapter(
        pool_connections=pool_config.max_connections if pool_config is not None else 10,
        pool_maxsize=pool_config.max_per_host if pool_config is not None else 10,
        max_retries=max_retries,
        storage=storage,
        ssl_context=ssl_context,
    )

    # Mount the cache adapter
@@ -299,8 +315,6 @@ def create_cached_session(

__all__ = [
    "PoolConfig",
    "SSLContextCacheAdapter",
    "SSLContextHTTPAdapter",
    "create_cached_session",
    "download_to_file",
    "resolve_ssl_verify",
+16 −32
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ from niquests.adapters import HTTPAdapter
from tdoc_crawler.config import DEFAULT_HTTP_CACHE_FILENAME
from tdoc_crawler.config.settings import HttpConfig
from tdoc_crawler.http_client import PoolConfig, create_cached_session, download_to_file
from tdoc_crawler.http_client.session import SSLContextCacheAdapter, SSLContextHTTPAdapter, resolve_ssl_verify
from tdoc_crawler.http_client.session import resolve_ssl_verify


class TestCreateCachedSession:
@@ -279,16 +279,6 @@ class TestResolveSslVerify:
        monkeypatch.setenv("TDC_VERIFY_SSL", "false")
        assert resolve_ssl_verify(True) is True

    def test_env_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """With TDC_VERIFY_SSL set to "no", resolve_ssl_verify() returns False."""
        monkeypatch.setenv("TDC_VERIFY_SSL", "no")

        resolved = resolve_ssl_verify()

        assert resolved is False

    def test_env_custom_ca_bundle(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A non-boolean TDC_VERIFY_SSL value is returned unchanged, as a CA bundle path."""
        monkeypatch.setenv("TDC_VERIFY_SSL", "C:/certs/corp.pem")

        resolved = resolve_ssl_verify()

        assert resolved == "C:/certs/corp.pem"

    def test_default_true(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """SSL verification should be enabled by default."""
        monkeypatch.delenv("TDC_VERIFY_SSL", raising=False)
@@ -296,35 +286,36 @@ class TestResolveSslVerify:


class TestSslContextAdapters:
    """Tests for SSL-context aware adapter behavior."""
    """Tests for adapter behavior with SSL verification settings."""

    def test_non_cached_verify_true_uses_ssl_context_adapter(self) -> None:
        """Non-cached sessions should mount SSL-aware adapters when verify=True."""
    def test_non_cached_verify_true_uses_http_adapter(self) -> None:
        """Non-cached sessions use default niquests HTTPAdapter; SSL is handled natively."""
        session = create_cached_session(http_cache_enabled=False, verify=True)
        try:
            assert isinstance(session.adapters["http://"], SSLContextHTTPAdapter)
            assert isinstance(session.adapters["https://"], SSLContextHTTPAdapter)
            assert isinstance(session.adapters["http://"], HTTPAdapter)
            assert isinstance(session.adapters["https://"], HTTPAdapter)
            assert not isinstance(session.adapters["http://"], CacheAdapter)
            assert session.verify is True
        finally:
            session.close()

    def test_non_cached_verify_false_uses_default_adapter(self) -> None:
        """Non-cached verify=False should use requests defaults when no pool override is set."""
        """Non-cached verify=False should use default HTTPAdapter."""
        session = create_cached_session(http_cache_enabled=False, verify=False)
        try:
            assert isinstance(session.adapters["http://"], HTTPAdapter)
            assert isinstance(session.adapters["https://"], HTTPAdapter)
            assert not isinstance(session.adapters["https://"], SSLContextHTTPAdapter)
            assert not isinstance(session.adapters["http://"], CacheAdapter)
            assert session.verify is False
        finally:
            session.close()

    def test_cached_verify_true_uses_ssl_cache_adapter(self, test_cache_dir: Path) -> None:
        """Cached sessions should mount SSL-aware cache adapters when verify=True."""
    def test_cached_verify_true_uses_cache_adapter(self, test_cache_dir: Path) -> None:
        """Cached sessions mount a CacheAdapter regardless of verify setting."""
        session = create_cached_session(http_cache_file=test_cache_dir / DEFAULT_HTTP_CACHE_FILENAME, verify=True)
        try:
            assert isinstance(session.adapters["http://"], SSLContextCacheAdapter)
            assert isinstance(session.adapters["https://"], SSLContextCacheAdapter)
            assert isinstance(session.adapters["http://"], CacheAdapter)
            assert isinstance(session.adapters["https://"], CacheAdapter)
            assert session.verify is True
        finally:
            session.close()
@@ -353,14 +344,6 @@ class TestHttpSessionWithConfig:

        assert result is True

    def test_resolve_ssl_verify_fallback_to_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """When no http_config is passed, TDC_VERIFY_SSL="false" yields False."""
        monkeypatch.setenv("TDC_VERIFY_SSL", "false")
        assert resolve_ssl_verify() is False

    def test_create_cached_session_uses_config_cache_enabled(self, test_cache_dir: Path) -> None:
        """create_cached_session respects http_config.cache_enabled=False."""
        http_config = HttpConfig(cache_enabled=False)
@@ -388,8 +371,9 @@ class TestHttpSessionWithConfig:

        # Explicit False should override http_config.cache_enabled=True
        assert session is not None
        # Should have plain HTTP adapter since caching is explicitly disabled
        assert isinstance(session.adapters["http://"], SSLContextHTTPAdapter)
        # Should have plain HTTP adapter since caching is explicitly disabled (no pool_config)
        assert isinstance(session.adapters["http://"], HTTPAdapter)
        assert not isinstance(session.adapters["http://"], CacheAdapter)
        session.close()