Commit 2730e4fe authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor(ai): streamline EmbeddingsManager initialization and config handling

* Remove unnecessary parameters from EmbeddingsManager constructor.
* Update ai_cache_dir handling to use a property for better clarity.
* Simplify storage retrieval in summarize and conftest modules.
parent 41e07f86
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ def create_embeddings_manager(config: AiConfig | None = None) -> EmbeddingsManager:
    """
    if config is None:
        config = AiConfig.from_env()
    return EmbeddingsManager.from_config(config)
    return EmbeddingsManager(config)


# Backward compatibility alias
+6 −13
Original line number Diff line number Diff line
@@ -85,8 +85,6 @@ def _validate_embedding_model_format(value: str) -> str:
class AiConfig(BaseConfigModel):
    """Configuration for the AI processing pipeline."""

    ai_cache_dir: Path | None = Field(None, description="Path to AI cache directory")

    embedding_model: str = Field(
        DEFAULT_EMBEDDING_MODEL,
        description="Embedding model in <provider>/<model_name> format",
@@ -122,9 +120,6 @@ class AiConfig(BaseConfigModel):
        if cache_manager_name := overrides.get("cache_manager_name"):
            data["cache_manager_name"] = cache_manager_name

        # NOTE: ai_cache_dir is NOT set here - it will be resolved in _resolve_paths
        # validator using ai_embed_dir(embedding_model) to include provider/model subdirectory

        if embedding_model := os.getenv("TDC_AI_EMBEDDING_MODEL"):
            data["embedding_model"] = embedding_model
        if embedding_backend := os.getenv("TDC_AI_EMBEDDING_BACKEND"):
@@ -162,15 +157,13 @@ class AiConfig(BaseConfigModel):
        filtered_data = {k: v for k, v in data.items() if v is not None}
        return cls(**filtered_data)

    @model_validator(mode="after")
    def _resolve_paths(self) -> AiConfig:
        if self.ai_cache_dir is None:

    @property
    def ai_cache_dir(self) -> Path:
        # Use CacheManager to resolve the embedding directory
        # e.g., ~/.tdoc-crawler/.ai/sentence-transformers/all-MiniLM-L6-v2
        # The ai_embed_dir method handles the provider/model subdirectory structure
            self.ai_cache_dir = resolve_cache_manager(self.cache_manager_name).ai_embed_dir(self.embedding_model)

        return self
        return resolve_cache_manager(self.cache_manager_name).ai_embed_dir(self.embedding_model)

    @model_validator(mode="after")
    def _validate_bounds(self) -> AiConfig:
+17 −58
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import logging
import re
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast, Literal
from typing import Any, cast

from sentence_transformers import SentenceTransformer

@@ -27,28 +27,22 @@ class EmbeddingsManager:
    """Manages embedding model and storage for document embedding and retrieval.

    This class handles both the embedding model lifecycle and storage initialization.
    Use the `from_config()` class method for automatic storage creation with the
    correct embedding dimension derived from the model.
    """

    def __init__(
        self,
        config: AiConfig,
        storage: AiStorage | None = None,
        model: SentenceTransformer | None = None,
        config: AiConfig
    ) -> None:
        """Initialize embeddings manager.

        Args:
            config: AI configuration.
            storage: Optional pre-created storage. If None, created lazily via from_config().
            model: Optional pre-loaded model. If None, loaded lazily.
        """
        self._config = config
        self._storage = storage
        self._model = model
        # Use the full embedding_model identifier (e.g., 'perplexity-ai/pplx-embed-v1-0.6B')
        self._model_name: str = config.embedding_model

        # will be initialized lazily on first access to allow for faster startup and to ensure correct device selection based on backend
        self._storage = None
        self._model = None

    @property
    def embedding_backend(self) -> Backend:
@@ -74,12 +68,11 @@ class EmbeddingsManager:
        Raises:
            RuntimeError: If neither storage nor model was provided at init.
        """

        if self._storage is None:
            # Force model load which creates storage
            _ = self.model
        if self._storage is None:
            msg = "No storage available. Use from_config() or provide storage at init."
            raise RuntimeError(msg)
            self._storage = AiStorage(self.config.ai_cache_dir, embedding_dimension=self.dimension)
            logger.info(f"Initialized AiStorage at {self.config.ai_cache_dir} with dimension {self.dimension}")

        return self._storage

    @property
@@ -87,18 +80,11 @@ class EmbeddingsManager:
        """Return the sentence-transformers model, loading it lazily on first access."""
        if self._model is None:
            self._model = SentenceTransformer(
                self._model_name,
                self._config.embedding_model,
                trust_remote_code=True,
                backend=self.embedding_backend,
                device=None, # uses GPU if available, otherwise CPU (note: OpenVINO backend requires device="cpu")
            )
            logger.info(f"Loaded embedding model: {self._model_name}")

            # If storage not yet created, create it now
            if self._storage is None:
                dimension = self._get_dimension(self._model)
                # ai_cache_dir is guaranteed to be set by AiConfig's _resolve_paths validator
                assert self._config.ai_cache_dir is not None, "ai_cache_dir must be set"
                self._storage = AiStorage(self._config.ai_cache_dir, embedding_dimension=dimension)
            logger.info(f"Loaded embedding model: {self._config.embedding_model}")

        return self._model

@@ -107,13 +93,13 @@ class EmbeddingsManager:
        """Return embedding dimension."""
        dim = self.model.get_sentence_embedding_dimension()
        if dim is None:
            raise RuntimeError(f"Model '{self._model_name}' did not report an embedding dimension")
            raise RuntimeError(f"Model '{self._config.embedding_model}' did not report an embedding dimension")
        return dim

    @property
    def model_name(self) -> str:
        """Return model name."""
        return self._model_name
        return self._config.embedding_model

    def encode(self, texts: Sequence[str]) -> list[list[float]]:
        """Encode texts to embeddings.
@@ -151,7 +137,7 @@ class EmbeddingsManager:
            return []

        # Create chunks
        chunks = self._create_chunks(document_id, markdown_content, self._model_name)
        chunks = self._create_chunks(document_id, markdown_content, self.model_name)

        if not chunks:
            return []
@@ -166,7 +152,7 @@ class EmbeddingsManager:
            # Attach embeddings to chunks
            for chunk, embedding in zip(chunks, embeddings, strict=False):
                chunk.embedding = [float(value) for value in embedding]
                chunk.embedding_model = self._model_name
                chunk.embedding_model = self._config.embedding_model

            storage.save_chunks(chunks, workspace=normalized_workspace)

@@ -204,33 +190,6 @@ class EmbeddingsManager:
        # Search in storage
        return storage.search_chunks(query_vector, top_k, workspace=normalized_workspace)

    @classmethod
    def from_config(cls, config: AiConfig) -> EmbeddingsManager:
        """Create EmbeddingsManager with storage from configuration.

        This is the recommended factory method. It:
        1. Loads the embedding model
        2. Extracts dimension from model
        3. Creates AiStorage with correct path and dimension
        4. Returns fully initialized manager

        Args:
            config: AI configuration with embedding_model and ai_cache_dir set.

        Returns:
            EmbeddingsManager with .storage property available.
        """
        model = SentenceTransformer(
            config.embedding_model,
            trust_remote_code=True,
            backend=config.embedding_backend,
        )
        dimension = cls._get_dimension(model)
        # ai_cache_dir is guaranteed to be set by AiConfig's _resolve_paths validator
        assert config.ai_cache_dir is not None, "ai_cache_dir must be set"
        storage = AiStorage(config.ai_cache_dir, embedding_dimension=dimension)
        return cls(config=config, storage=storage, model=model)

    @classmethod
    def _chunk_by_paragraphs(cls, text: str, max_chars: int = DEFAULT_MAX_CHARS) -> list[str]:
        """Split text into chunks by paragraphs.
+3 −7
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@ import re

import litellm

from tdoc_ai.operations.embeddings import EmbeddingsManager
from tdoc_ai.config import AiConfig
from tdoc_ai.models import DocumentSummary, LlmConfigError, SummarizeResult
from tdoc_ai.operations.workspace_names import normalize_workspace_name
@@ -141,10 +142,7 @@ def _should_skip_summary(
            existing = storage_client.get_summary_hash(document_id)
            return existing == content_hash

        # Get storage via EmbeddingsManager
        from tdoc_ai.operations.embeddings import EmbeddingsManager

        storage_client = EmbeddingsManager.from_config(AiConfig.from_env()).storage
        storage_client = EmbeddingsManager(AiConfig.from_env()).storage
        existing = storage_client.get_summary_hash(document_id)
        return existing == content_hash
    except Exception:
@@ -207,10 +205,8 @@ def summarize_document(
        DocumentSummary object.
    """
    config = AiConfig.from_env()
    # Get storage via EmbeddingsManager
    from tdoc_ai.operations.embeddings import EmbeddingsManager

    storage_client = storage or EmbeddingsManager.from_config(config).storage
    storage_client = storage or EmbeddingsManager(config).storage

    normalized_workspace = normalize_workspace_name(workspace)

+1 −1
Original line number Diff line number Diff line
@@ -181,7 +181,7 @@ def ai_storage(test_cache_dir: Path) -> AiStorage:
    # Patch environment variable to point to test cache directory
    with patch.dict("os.environ", {"TDOC_CRAWLER_AI_CACHE_DIR": str(test_cache_dir / "ai")}):
        # Use new factory method
        embeddings_manager = EmbeddingsManager.from_config(AiConfig.from_env())
        embeddings_manager = EmbeddingsManager(AiConfig.from_env())
        return embeddings_manager.storage