Commit a8c73d3b authored by Jan Reimes's avatar Jan Reimes
Browse files

Fix: Cache embedding model across document processing

The embedding model was being reloaded for each document because
process_tdoc() and process_all() created new EmbeddingsManager
instances instead of using the singleton AiServiceContainer.

Now both functions use AiServiceContainer.get_instance().get_embeddings_manager()
to get the cached manager, ensuring the model is loaded once and reused
across all documents.

Also uses AiServiceContainer.get_instance().get_ai_storage() for consistent storage access.
parent 7f1e381e
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from collections.abc import Callable
from pathlib import Path
from typing import Any

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.container import AiServiceContainer
from tdoc_crawler.ai.models import (
    DocumentClassification,
    DocumentSummary,
@@ -17,7 +17,7 @@ from tdoc_crawler.ai.models import (
    ProcessingStatus,
)
from tdoc_crawler.ai.operations.classify import classify_document_files
from tdoc_crawler.ai.operations.embeddings import EmbeddingsManager, generate_embeddings
from tdoc_crawler.ai.operations.embeddings import generate_embeddings
from tdoc_crawler.ai.operations.extract import extract_from_folder
from tdoc_crawler.ai.operations.summarize import summarize_document
from tdoc_crawler.ai.operations.workspaces import list_workspace_members, normalize_workspace_name
@@ -330,9 +330,9 @@ def process_tdoc(
    Raises:
        Exception: Re-raises pipeline exceptions after logging.
    """
    config = AiConfig.from_env()
    manager = EmbeddingsManager(config=config)
    storage = AiStorage(config.ai_cache_dir, embedding_dimension=manager.dimension)  # type: ignore[arg-type]
    container = AiServiceContainer.get_instance()
    manager = container.get_embeddings_manager()
    storage = container.get_ai_storage()
    return run_pipeline(
        document_id,
        checkout_path,
@@ -363,9 +363,9 @@ def process_all(
    Returns:
        Dict mapping document_id to ProcessingStatus.
    """
    config = AiConfig.from_env()
    manager = EmbeddingsManager(config=config)
    storage = AiStorage(config.ai_cache_dir, embedding_dimension=manager.dimension)  # type: ignore[arg-type]
    container = AiServiceContainer.get_instance()
    manager = container.get_embeddings_manager()
    storage = container.get_ai_storage()
    normalized_workspace = normalize_workspace_name(workspace)

    # Get workspace members and build a lookup map