Commit 0bcfa080 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(ai): add embedding backend configuration and update related components

* Introduced TDC_AI_EMBEDDING_BACKEND environment variable for backend selection.
* Updated AiConfig to include embedding backend with validation.
* Modified EmbeddingsManager to utilize the selected backend.
* Enhanced CLI commands to accept backend options via --accelerate flag.
* Added tests for embedding backend configuration and validation.
parent 983cf4a6
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -96,6 +96,11 @@ TDC_AI_LLM_API_BASE=
# See https://huggingface.co/models?library=sentence-transformers for alternatives
TDC_AI_EMBEDDING_MODEL=perplexity-ai/pplx-embed-context-v1-0.6b

# Embedding backend for sentence-transformers (default: torch)
# Options: torch (default), onnx (faster inference), openvino (Intel hardware optimization)
# Can also be set via --accelerate/-a CLI option for ai process command
TDC_AI_EMBEDDING_BACKEND=torch

# Activate workspace after creation (default: true)
# Set to "true", "1", or "yes" to enable; anything else disables it
TDC_AI_WORKSPACE_ACTIVATE=true
+1 −1
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ dependencies = [
    "kreuzberg[all]>=4.0.0",
    "lancedb>=0.29.2",
    "litellm>=1.81.15",
    "sentence-transformers[openvino]>=2.7.0",
    "sentence-transformers[openvino,onnx-gpu]>=2.7.0",
    "tokenizers>=0.22.2",
]

+18 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@
from __future__ import annotations

import os
from typing import Literal
from pathlib import Path

import litellm
@@ -14,6 +15,7 @@ from tdoc_crawler.models import BaseConfigModel
# Fallback model identifiers used when the corresponding TDC_AI_* env vars are unset.
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_LLM_MODEL = "openrouter/openrouter/free"

# Allowed sentence-transformers inference backends; "torch" is the default.
# NOTE: PEP 695 `type` alias syntax requires Python 3.12+.
type Backend = Literal["torch", "onnx", "openvino"]

def _env_int(name: str) -> int | None:
    value = os.getenv(name)
@@ -85,6 +87,10 @@ class AiConfig(BaseConfigModel):
        DEFAULT_EMBEDDING_MODEL,
        description="Embedding model in <provider>/<model_name> format",
    )
    embedding_backend: Backend = Field(
        "torch",
        description="Sentence-transformers backend (torch, onnx, openvino)",
    )
    max_chunk_size: int = Field(1000, ge=1, description="Max tokens per chunk")
    chunk_overlap: int = Field(100, ge=0, description="Token overlap between chunks")

@@ -112,6 +118,8 @@ class AiConfig(BaseConfigModel):

        if embedding_model := os.getenv("TDC_AI_EMBEDDING_MODEL"):
            data["embedding_model"] = embedding_model
        if embedding_backend := os.getenv("TDC_AI_EMBEDDING_BACKEND"):
            data["embedding_backend"] = embedding_backend
        if llm_model := os.getenv("TDC_AI_LLM_MODEL"):
            data["llm_model"] = llm_model
        if llm_api_base := os.getenv("TDC_AI_LLM_API_BASE"):
@@ -167,6 +175,16 @@ class AiConfig(BaseConfigModel):
    def _validate_embedding_model(cls, value: str) -> str:
        return _validate_embedding_model_format(value)

    @field_validator("embedding_backend")
    @classmethod
    def _validate_embedding_backend(cls, value: str) -> str:
        """Normalize the backend name and reject anything not supported.

        Lower-cases and strips the raw value (e.g. from the
        TDC_AI_EMBEDDING_BACKEND env var) so "Torch " and "torch" are
        treated the same.

        Raises:
            ValueError: If the normalized value is not torch, onnx, or openvino.
        """
        backend = value.strip().lower()
        if backend in ("torch", "onnx", "openvino"):
            return backend
        msg = "embedding_backend must be one of: torch, onnx, openvino"
        raise ValueError(msg)

    @field_validator("llm_model")
    @classmethod
    def _validate_llm_model(cls, value: str) -> str:
+17 −4
Original line number Diff line number Diff line
@@ -5,11 +5,11 @@ import logging
import re
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast
from typing import Any, cast, Literal

from sentence_transformers import SentenceTransformer

from tdoc_ai.config import AiConfig
from tdoc_ai.config import AiConfig, Backend
from tdoc_ai.models import DocumentChunk
from tdoc_ai.operations.workspace_names import normalize_workspace_name
from tdoc_ai.storage import AiStorage
@@ -50,6 +50,11 @@ class EmbeddingsManager:
        # Use the full embedding_model identifier (e.g., 'perplexity-ai/pplx-embed-v1-0.6B')
        self._model_name: str = config.embedding_model

    @property
    def embedding_backend(self) -> Backend:
        """Sentence-transformers backend selected in the AI configuration."""
        cfg = self.config
        return cfg.embedding_backend

    @property
    def config(self) -> AiConfig:
        """Return the AI configuration."""
@@ -81,7 +86,11 @@ class EmbeddingsManager:
    def model(self) -> SentenceTransformer:
        """Return the sentence-transformers model, loading it lazily on first access."""
        if self._model is None:
            self._model = SentenceTransformer(self._model_name, trust_remote_code=True)
            self._model = SentenceTransformer(
                self._model_name,
                trust_remote_code=True,
                backend=self.embedding_backend,
            )
            logger.info(f"Loaded embedding model: {self._model_name}")

            # If storage not yet created, create it now
@@ -211,7 +220,11 @@ class EmbeddingsManager:
        Returns:
            EmbeddingsManager with .storage property available.
        """
        model = SentenceTransformer(config.embedding_model, trust_remote_code=True)
        model = SentenceTransformer(
            config.embedding_model,
            trust_remote_code=True,
            backend=config.embedding_backend,
        )
        dimension = cls._get_dimension(model)
        # ai_cache_dir is guaranteed to be set by AiConfig's _resolve_paths validator
        assert config.ai_cache_dir is not None, "ai_cache_dir must be set"
+17 −4
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ def run_pipeline( # noqa: PLR0915
    progress_callback: Callable[[PipelineStage, str], None] | None = None,
    force_rerun: bool = False,
    workspace: str | None = None,
    config: AiConfig | None = None,
) -> ProcessingStatus:
    """Run the complete AI processing pipeline for a TDoc.

@@ -107,7 +108,14 @@ def run_pipeline( # noqa: PLR0915
        if not force_rerun and status.embedded_at:
            logger.info(f"Skipping embedding for {document_id} - already embedded")
        else:
            _run_embedding_stage(document_id, folder_path, storage, status, workspace=normalized_workspace)
            _run_embedding_stage(
                document_id,
                folder_path,
                storage,
                status,
                workspace=normalized_workspace,
                config=config,
            )
            if progress_callback:
                progress_callback(PipelineStage.EMBEDDING, document_id)
    except Exception as e:
@@ -229,6 +237,7 @@ def _run_embedding_stage(
    storage: AiStorage,
    status: ProcessingStatus,
    workspace: str | None = None,
    config: AiConfig | None = None,
) -> None:
    """Run embedding stage based on extracted markdown artifact."""
    status.current_stage = PipelineStage.EMBEDDING
@@ -241,7 +250,7 @@ def _run_embedding_stage(
        raise FileNotFoundError(msg)

    # Get embeddings manager from config
    config = AiConfig.from_env()
    config = config or AiConfig.from_env()
    embeddings_manager = EmbeddingsManager.from_config(config)
    embeddings_manager.generate_embeddings(document_id, artifact_path, workspace=workspace)

@@ -255,6 +264,7 @@ def process_tdoc(
    checkout_path: Path,
    force_rerun: bool = False,
    workspace: str | None = None,
    config: AiConfig | None = None,
) -> ProcessingStatus:
    """Process a single TDoc through the AI pipeline.

@@ -272,7 +282,7 @@ def process_tdoc(
    Raises:
        Exception: Re-raises pipeline exceptions after logging.
    """
    config = AiConfig.from_env()
    config = config or AiConfig.from_env()
    embeddings_manager = EmbeddingsManager.from_config(config)
    storage = embeddings_manager.storage
    return run_pipeline(
@@ -281,6 +291,7 @@ def process_tdoc(
        storage,
        force_rerun=force_rerun,
        workspace=workspace,
        config=config,
    )


@@ -291,6 +302,7 @@ def process_all(
    force_rerun: bool = False,
    progress_callback: Callable[[PipelineStage, str], None] | None = None,
    workspace: str | None = None,
    config: AiConfig | None = None,
) -> dict[str, ProcessingStatus]:
    """Process multiple TDocs through the AI pipeline.

@@ -305,7 +317,7 @@ def process_all(
    Returns:
        Dict mapping document_id to ProcessingStatus.
    """
    config = AiConfig.from_env()
    config = config or AiConfig.from_env()
    embeddings_manager = EmbeddingsManager.from_config(config)
    storage = embeddings_manager.storage
    normalized_workspace = normalize_workspace_name(workspace)
@@ -340,6 +352,7 @@ def process_all(
                progress_callback=progress_callback,
                force_rerun=force_rerun,
                workspace=normalized_workspace,
                config=config,
            )
            results[document_id] = status
        except Exception as e:
Loading