Commit bf87c23a authored by Jan Reimes
Browse files

feat(ai): enhance AI processing pipeline and workspace management

* Update embedding model to "sentence-transformers/all-MiniLM-L6-v2"
* Introduce new functions for processing TDocs and querying embeddings
* Improve workspace member management with new clear and remove commands
* Ensure embedding dimensions are handled correctly in storage
* Refactor CLI commands for better workspace interaction
parent e3a86341
Loading
Loading
Loading
Loading
+40 −22
Original line number Diff line number Diff line
@@ -48,24 +48,35 @@ def _pipeline_get_status(tdoc_id: str, workspace: str) -> ProcessingStatus | Non
    return _pipeline_get_status_impl(tdoc_id, workspace=workspace)


def query_embeddings(query: str, workspace: str, top_k: int = 10) -> list:
    """Run a semantic-search query over stored embeddings.

    Thin wrapper that forwards the query string, workspace, and the
    top_k result limit unchanged to the internal implementation.
    """
    results = _query_embeddings(query, workspace=workspace, top_k=top_k)
    return results
def get_status(tdoc_id: str, workspace: str | None = None) -> ProcessingStatus | None:
    """Look up the processing status for a TDoc.

    Thin wrapper around the internal status implementation; per the
    return annotation the result may be None.
    """
    status = _pipeline_get_status_impl(tdoc_id, workspace=workspace)
    return status


def query_graph(query: str, workspace: str, top_k: int = 10) -> list:
    """Query the knowledge graph for a given query string.

    Forwards the query, workspace, and top_k limit unchanged to the
    internal graph-query implementation.
    """
    graph_results = _query_graph(query, workspace=workspace, top_k=top_k)
    return graph_results
def process_tdoc(
    tdoc_id: str,
    checkout_path: Path,
    force_rerun: bool = False,
    workspace: str | None = None,
) -> ProcessingStatus:
    """Process a single TDoc through the AI pipeline.

    Thin wrapper: all arguments are forwarded unchanged to the internal
    pipeline implementation, whose ProcessingStatus is returned as-is.
    """
    return _pipeline_process_tdoc_impl(
        tdoc_id, checkout_path, force_rerun=force_rerun, workspace=workspace
    )


def _pipeline_process_all(
def process_all(
    tdoc_ids: list[str],
    checkout_base: Path,
    new_only: bool,
    force_rerun: bool,
    progress_callback: Callable[[PipelineStage, str], None] | None,
    workspace: str,
    new_only: bool = False,
    force_rerun: bool = False,
    progress_callback: Callable[[PipelineStage, str], None] | None = None,
    workspace: str | None = None,
) -> dict[str, ProcessingStatus]:
    """Process multiple TDocs through the AI pipeline."""
    return _pipeline_process_all_impl(
        tdoc_ids,
        checkout_base,
@@ -76,18 +87,22 @@ def _pipeline_process_all(
    )


def _pipeline_process_tdoc(
    tdoc_id: str,
    checkout_path: Path,
    force_rerun: bool,
def query_embeddings(
    query: str,
    workspace: str,
) -> ProcessingStatus:
    return _pipeline_process_tdoc_impl(
        tdoc_id,
        checkout_path,
        force_rerun=force_rerun,
        workspace=workspace,
    )
    top_k: int = 5,
) -> list[tuple[DocumentChunk, float]]:
    """Query embeddings for semantic search."""
    return _query_embeddings(query, workspace, top_k)


def query_graph(query: str, workspace: str, top_k: int = 5) -> list[dict]:
    """Query knowledge graph.

    Forwards the query string, workspace, and top_k limit positionally
    to the internal graph-query implementation.
    """
    return _query_graph(query, workspace, top_k)


__all__ = [
@@ -111,11 +126,14 @@ __all__ = [
    "delete_workspace",
    "ensure_ai_subfolder",
    "ensure_default_workspace",
    "get_status",
    "get_workspace",
    "is_default_workspace",
    "list_workspaces",
    "make_workspace_member",
    "normalize_workspace_name",
    "process_all",
    "process_tdoc",
    "query_embeddings",
    "query_graph",
    "resolve_tdoc_checkout_path",
+32 −3
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from pydantic import Field, field_validator, model_validator

from tdoc_crawler.models import BaseConfigModel

DEFAULT_EMBEDDING_MODEL = "ollama/embeddinggemma"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_LLM_MODEL = "openrouter/openrouter/free"

LITELLM_PROVIDER_ALLOWLIST = {
@@ -67,6 +67,32 @@ def _validate_model_identifier(value: str, field_name: str) -> str:
    return f"{provider_normalized}/{model_name_normalized}"


def _validate_embedding_model_format(value: str) -> str:
    """Validate embedding model - accepts any HuggingFace-style model ID.

    Unlike LLM models, embedding models via sentence-transformers don't require
    LiteLLM provider validation. Accepts formats like:
    - sentence-transformers/all-MiniLM-L6-v2
    - perplexity-ai/pplx-embed-v1-0.6b
    """
    if "/" not in value:
        msg = "embedding_model must be in '<provider>/<model_name>' format"
        raise ValueError(msg)

    provider, model_name = value.split("/", 1)
    provider_normalized = provider.strip().lower()
    model_name_normalized = model_name.strip()

    if not provider_normalized:
        msg = "embedding_model provider segment cannot be empty"
        raise ValueError(msg)
    if not model_name_normalized:
        msg = "embedding_model model_name segment cannot be empty"
        raise ValueError(msg)

    return f"{provider_normalized}/{model_name_normalized}"


class AiConfig(BaseConfigModel):
    """Configuration for the AI processing pipeline."""

@@ -131,7 +157,10 @@ class AiConfig(BaseConfigModel):
    @model_validator(mode="after")
    def _resolve_paths(self) -> AiConfig:
        if self.ai_store_path is None:
            self.ai_store_path = self.cache_dir / ".ai" / "lancedb"
            # Include embedding model in path to avoid dimension conflicts
            # e.g., ~/.tdoc-crawler/.ai/lancedb/sentence-transformers/all-MiniLM-L6-v2
            # Keep slash to group models by provider
            self.ai_store_path = self.cache_dir / ".ai" / "lancedb" / self.embedding_model
        return self

    @model_validator(mode="after")
@@ -147,7 +176,7 @@ class AiConfig(BaseConfigModel):
    @field_validator("embedding_model")
    @classmethod
    def _validate_embedding_model(cls, value: str) -> str:
        return _validate_model_identifier(value, "embedding_model")
        return _validate_embedding_model_format(value)

    @field_validator("llm_model")
    @classmethod
+11 −4
Original line number Diff line number Diff line
@@ -71,6 +71,12 @@ class EmbeddingModelWrapper:
    @property
    def dimension(self) -> int:
        """Return embedding dimension.

        Falls back to DEFAULT_DIMENSION when no model is loaded, or when
        the loaded model cannot report its own dimension.
        """
        if self._model is None:
            return DEFAULT_DIMENSION
        # Get dimension from the actual model
        try:
            return self._model.get_sentence_embedding_dimension()
        except Exception:
            # Best-effort fallback. NOTE(review): if the configured model's
            # real dimension differs from DEFAULT_DIMENSION, this would
            # misreport it -- confirm the fallback matches the default model.
            return DEFAULT_DIMENSION

    @property
@@ -279,7 +285,8 @@ def _should_skip_embedding(tdoc_id: str, content_hash: str) -> bool:
    """
    try:
        config = AiConfig.from_env(cache_manager_name="default")
        storage_client = AiStorage(config.ai_store_path)  # type: ignore[arg-type]
        model = get_embedding_model()
        storage_client = AiStorage(config.ai_store_path, embedding_dimension=model.dimension)
        existing_hash = storage_client.get_embedding_hash(tdoc_id)
        return existing_hash == content_hash
    except Exception:
@@ -305,10 +312,10 @@ def generate_embeddings(
    """
    if storage is None:
        config = AiConfig.from_env(cache_manager_name="default")
        storage = AiStorage(config.ai_store_path)  # type: ignore[arg-type]
        model = get_embedding_model()
        storage = AiStorage(config.ai_store_path, embedding_dimension=model.dimension)

    normalized_workspace = normalize_workspace_name(workspace)

    markdown_content = markdown.read_text(encoding="utf-8") if isinstance(markdown, Path) else markdown

    # Compute content hash for idempotency
+1 −1
Original line number Diff line number Diff line
@@ -116,7 +116,7 @@ def extract_docx_to_markdown(
                ngram_range=(1, 3),
            ),
            language_detection=LanguageDetectionConfig(
                confidence_threshold=0.7,
                min_confidence=0.7,
            ),
            enable_quality_processing=True,
        )
+17 −11
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from collections.abc import Callable
from pathlib import Path
from typing import Any

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import (
    DocumentClassification,
    DocumentSummary,
@@ -21,7 +22,6 @@ from tdoc_crawler.ai.operations.extract import extract_from_folder
from tdoc_crawler.ai.operations.summarize import summarize_document
from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
from tdoc_crawler.ai.storage import AiStorage
from tdoc_crawler.config import CacheManager
from tdoc_crawler.utils.misc import utc_now

logger = logging.getLogger(__name__)
@@ -330,7 +330,11 @@ def process_tdoc(
    Raises:
        Exception: Re-raises pipeline exceptions after logging.
    """
    storage = AiStorage(checkout_path.parent / ".ai" / "lancedb")
    from tdoc_crawler.ai.operations.embeddings import get_embedding_model

    config = AiConfig.from_env(cache_manager_name="default")
    model = get_embedding_model()
    storage = AiStorage(config.ai_store_path, embedding_dimension=model.dimension)
    return run_pipeline(
        tdoc_id,
        checkout_path,
@@ -361,7 +365,11 @@ def process_all(
    Returns:
        Dict mapping tdoc_id to ProcessingStatus.
    """
    storage = AiStorage(checkout_base / ".ai" / "lancedb")
    from tdoc_crawler.ai.operations.embeddings import get_embedding_model

    config = AiConfig.from_env(cache_manager_name="default")
    model = get_embedding_model()
    storage = AiStorage(config.ai_store_path, embedding_dimension=model.dimension)
    normalized_workspace = normalize_workspace_name(workspace)

    # Get workspace members and build a lookup map
@@ -371,7 +379,6 @@ def process_all(
        member_ids = {m.source_item_id for m in members if m.is_active and m.source_kind == "tdoc"}
        members_map = {m.source_item_id: m for m in members if m.is_active and m.source_kind == "tdoc"}
        tdoc_ids = [tid for tid in tdoc_ids if tid in member_ids]

    results: dict[str, ProcessingStatus] = {}
    for tdoc_id in tdoc_ids:
        if new_only and not force_rerun:
@@ -382,11 +389,7 @@ def process_all(

        # Use source_path from workspace member if available, otherwise fallback to default
        member = members_map.get(tdoc_id)
        folder_path = (
            Path(member.source_path)
            if member and member.source_path and Path(member.source_path).exists()
            else checkout_base / tdoc_id
        )
        folder_path = Path(member.source_path) if member and member.source_path and Path(member.source_path).exists() else checkout_base / tdoc_id

        if not folder_path.exists():
            logger.warning(f"Checkout folder not found: {folder_path}")
@@ -418,8 +421,11 @@ def get_status(tdoc_id: str, workspace: str | None = None) -> ProcessingStatus |
    Returns:
        ProcessingStatus if found, None otherwise.
    """
    manager = CacheManager().register()
    storage = AiStorage(manager.root / ".ai" / "lancedb")
    from tdoc_crawler.ai.operations.embeddings import get_embedding_model

    config = AiConfig.from_env(cache_manager_name="default")
    model = get_embedding_model()
    storage = AiStorage(config.ai_store_path, embedding_dimension=model.dimension)
    normalized_workspace = normalize_workspace_name(workspace)
    return storage.get_status(tdoc_id, workspace=normalized_workspace)

Loading