Commit 76993b4b authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor: implement DI container for AI module and remove unsafe config access

- Add AiServiceContainer as singleton DI entry point for AiConfig, AiStorage, and EmbeddingsManager
- Replace ad-hoc AiStorage instantiations with container.get_storage() across all operations
- Remove direct AiConfig.from_env() usage in operational code; config now accessed via container
- Fix circular import between embeddings module and container using TYPE_CHECKING
- Fix EmbeddingsManager initialization to require explicit config and storage (no None defaults)
- Fix get_tdoc_evolution to require explicit storage parameter (dependency injection)
- Fix LanceDB schema handling: use list_tables() instead of deprecated table_names()
- Fix LanceDB vector schema for fixed-size list type
- Add helper functions (_is_nan, _is_nat, _table_to_records) for robust data handling
- Update get_ai_config() in ConfigService to delegate to AiServiceContainer
- Reorganize container methods: move class methods after instance methods for consistency
parent 8d70e7ae
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from __future__ import annotations
import litellm

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.container import AiServiceContainer
from tdoc_crawler.ai.models import (
    DocumentChunk,
    DocumentClassification,
@@ -48,7 +49,6 @@ from tdoc_crawler.ai.operations.workspaces import (
    resolve_workspace,
)
from tdoc_crawler.ai.storage import AiStorage
from tdoc_crawler.ai.container import AiServiceContainer
from tdoc_crawler.config import CacheManager

litellm.suppress_debug_info = True  # Suppress provider/model info logs from litellm
@@ -64,6 +64,7 @@ _query_graph = query_graph
__all__ = [
    "DEFAULT_WORKSPACE",
    "AiConfig",
    "AiServiceContainer",
    "AiStorage",
    "CacheManager",
    "DocumentChunk",
+18 −20
Original line number Diff line number Diff line
@@ -12,11 +12,9 @@ The container implements lazy initialization and singleton pattern to ensure:

from __future__ import annotations

from pathlib import Path
from typing import Any

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import ProcessingStatus
from tdoc_crawler.ai.operations.embeddings import EmbeddingsManager
from tdoc_crawler.ai.storage import AiStorage

@@ -57,23 +55,6 @@ class AiServiceContainer:
            cls._instance._embeddings_manager = None
        return cls._instance

    @classmethod
    def get_instance(cls) -> AiServiceContainer:
        """Get the singleton container instance.

        Returns:
            AiServiceContainer singleton instance.
        """
        return cls()

    @classmethod
    def reset_for_testing(cls) -> None:
        """Reset the singleton for testing purposes.

        WARNING: Only use in tests, not in production code.
        """
        cls._instance = None

    def get_config(self) -> AiConfig:
        """Get the AI configuration singleton.

@@ -138,7 +119,7 @@ class AiServiceContainer:
        self._embeddings_manager = None
        self._config = None

    def __enter__(self) -> "AiServiceContainer":
    def __enter__(self) -> AiServiceContainer:
        """Context manager entry."""
        return self

@@ -146,6 +127,23 @@ class AiServiceContainer:
        """Context manager exit - ensures cleanup."""
        self.close()

    @classmethod
    def get_instance(cls) -> AiServiceContainer:
        """Get the singleton container instance.

        Returns:
            AiServiceContainer singleton instance.
        """
        return cls()

    @classmethod
    def reset_for_testing(cls) -> None:
        """Reset the singleton for testing purposes.

        WARNING: Only use in tests, not in production code.
        """
        cls._instance = None


# Convenience functions for backward compatibility
def get_ai_config() -> AiConfig:
+13 −13
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import logging
import re
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast

from sentence_transformers import SentenceTransformer

@@ -15,6 +15,9 @@ from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
from tdoc_crawler.ai.storage import AiStorage
from tdoc_crawler.utils.misc import utc_now

if TYPE_CHECKING:
    pass

logger = logging.getLogger(__name__)

# Chunk size settings
@@ -25,15 +28,13 @@ DEFAULT_OVERLAP = 50
class EmbeddingsManager:
    """Manages embedding model and storage for document embedding and retrieval."""

    def __init__(self, config: AiConfig | None = None, storage: AiStorage | None = None) -> None:
        """Initialize the embeddings manager.
    def __init__(self, config: AiConfig, storage: AiStorage) -> None:
        """Initialize embeddings manager.

        Args:
            config: AI configuration. If None, resolved from environment via AiConfig.from_env().
            storage: AI storage instance. If None, created lazily in generate_embeddings.
            config: AI configuration.
            storage: AI storage instance.
        """
        if config is None:
            config = AiConfig.from_env()
        self._config = config
        self._storage = storage
        # Use the full embedding_model identifier (e.g., 'perplexity-ai/pplx-embed-v1-0.6B')
@@ -84,11 +85,6 @@ class EmbeddingsManager:
        Returns:
            List of DocumentChunk objects with embeddings.
        """
        if self._storage is None:
            if self._config.ai_cache_dir is None:
                raise RuntimeError("ai_cache_dir is not configured in AiConfig")
            self._storage = AiStorage(self._config.ai_cache_dir, embedding_dimension=self.dimension)

        normalized_workspace = normalize_workspace_name(workspace)
        markdown_content = markdown.read_text(encoding="utf-8") if isinstance(markdown, Path) else markdown

@@ -141,7 +137,11 @@ class EmbeddingsManager:
        Returns:
            List of (DocumentChunk, score) tuples.
        """
        storage = self._storage or AiStorage(self._config.ai_cache_dir, embedding_dimension=self.dimension)
        if self._storage is None:
            from tdoc_crawler.ai.container import AiServiceContainer

            self._storage = AiServiceContainer.get_instance().get_storage()
        storage = self._storage
        normalized_workspace = normalize_workspace_name(workspace)

        # Encode query
+11 −29
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ import logging
import re
from datetime import datetime

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.container import AiServiceContainer
from tdoc_crawler.ai.models import GraphEdge, GraphEdgeType, GraphNode, GraphNodeType, GraphQueryResult
from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
from tdoc_crawler.ai.storage import AiStorage
@@ -143,12 +143,7 @@ def build_graph(
) -> tuple[list[GraphNode], list[GraphEdge]]:
    """Build knowledge graph from TDoc content with incremental updates."""
    if storage is None:
        config = AiConfig.from_env()
        store_path = config.ai_cache_dir
        if store_path is None:
            msg = "AI store path not configured"
            raise ValueError(msg)
        storage = AiStorage(store_path)
        storage = AiServiceContainer.get_instance().get_storage()

    existing_nodes, existing_edges = storage.query_graph(filters={}, workspace=workspace)
    existing_node_ids = {n.node_id for n in existing_nodes}
@@ -332,34 +327,29 @@ def build_graph(

def query_graph(
    query: str,
    workspace: str,
    node_types: list[GraphNodeType] | None = None,
    top_k: int = 10,
    meeting_ids: list[str] | None = None,
    date_range: tuple[datetime, datetime] | None = None,
    top_k: int = 10,
    storage: AiStorage | None = None,
) -> dict[str, list[GraphQueryResult]]:
    """Query the knowledge graph with temporal filtering and chronological sorting.
    workspace: str | None = None,
) -> dict:
    """Query knowledge graph for relevant nodes and edges.

    Args:
        query: Search query string.
        workspace: Workspace scope (required).
        query: Natural language query.
        node_types: Filter by node types.
        top_k: Number of results to return.
        meeting_ids: Filter by meeting IDs.
        date_range: Filter by date range.
        top_k: Maximum number of results to return.
        storage: Optional storage instance.
        workspace: Workspace name for filtering.

    Returns:
        Dict with 'results' key containing list of GraphQueryResult objects.
    """
    if storage is None:
        config = AiConfig.from_env()
        store_path = config.ai_cache_dir
        if store_path is None:
            msg = "AI store path not configured"
            raise ValueError(msg)
        storage = AiStorage(store_path)
        storage = AiServiceContainer.get_instance().get_storage()

    normalized_workspace = normalize_workspace_name(workspace)

@@ -392,16 +382,8 @@ def query_graph(
    return {"results": results}


def get_tdoc_evolution(document_id: str, storage: AiStorage | None = None) -> list[GraphNode]:
def get_tdoc_evolution(document_id: str, storage: AiStorage) -> list[GraphNode]:
    """Get evolution chain for a TDoc (revisions, supersessions)."""
    if storage is None:
        config = AiConfig.from_env()
        store_path = config.ai_cache_dir
        if store_path is None:
            msg = "AI store path not configured"
            raise ValueError(msg)
        storage = AiStorage(store_path)

    nodes, edges = storage.query_graph(filters={})
    related_ids = {document_id}

+3 −5
Original line number Diff line number Diff line
@@ -331,7 +331,7 @@ def process_tdoc(
        Exception: Re-raises pipeline exceptions after logging.
    """
    container = AiServiceContainer.get_instance()
    manager = container.get_embeddings_manager()
    container.get_embeddings_manager()
    storage = container.get_ai_storage()
    return run_pipeline(
        document_id,
@@ -364,7 +364,7 @@ def process_all(
        Dict mapping document_id to ProcessingStatus.
    """
    container = AiServiceContainer.get_instance()
    manager = container.get_embeddings_manager()
    container.get_embeddings_manager()
    storage = container.get_ai_storage()
    normalized_workspace = normalize_workspace_name(workspace)

@@ -417,9 +417,7 @@ def get_status(document_id: str, workspace: str | None = None) -> ProcessingStat
    Returns:
        ProcessingStatus if found, None otherwise.
    """
    config = AiConfig.from_env()
    manager = EmbeddingsManager(config=config)
    storage = AiStorage(config.ai_cache_dir, embedding_dimension=manager.dimension)  # type: ignore[arg-type]
    storage = AiServiceContainer.get_instance().get_storage()
    normalized_workspace = normalize_workspace_name(workspace)
    return storage.get_status(document_id, workspace=normalized_workspace)

Loading