Commit 2782ccc3 authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor(ai): improve cache management and enhance embedding queries

* Simplified cache manager initialization by using a decorator.
* Updated EmbeddingsManager instantiation to ensure proper configuration loading.
* Enhanced result formatting to use section headings instead of sections.
* Added progress tracking for document processing phases.
parent 2730e4fe
Loading
Loading
Loading
Loading
+70 −35
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from __future__ import annotations

import json
from datetime import UTC, datetime
from functools import cache
from pathlib import Path
from typing import Annotated

@@ -47,8 +48,8 @@ from tdoc_crawler.cli.args import (
    CheckoutPathOption,
    ConvertDocumentArgument,
    ConvertOutputOption,
    EmbeddingTopKOption,
    EmbeddingBackendOption,
    EmbeddingTopKOption,
    EndDateOption,
    GraphQueryArgument,
    GraphQueryOption,
@@ -91,17 +92,11 @@ from tdoc_crawler.utils.date_parser import parse_partial_date
ai_app = typer.Typer(help="AI document processing commands")
console = Console()

@cache
def _get_cache_manager() -> CacheManager:
    """Get (and lazily create) the process-wide default cache manager.

    Memoized with ``functools.cache`` so ``CacheManager.register()`` runs at
    most once per process, avoiding repeated register() calls. This replaces
    the earlier hand-rolled module-level ``_default_cache_manager`` singleton,
    which duplicated what the decorator already provides.

    Returns:
        The registered default :class:`CacheManager` instance.

    NOTE(review): assumes ``register()`` returns the manager instance for
    chaining — confirm against CacheManager's API.
    """
    return CacheManager().register()


def resolve_workspace(workspace: str | None) -> str:
@@ -184,7 +179,8 @@ def ai_query(
        raise typer.Exit(1)

    resolved_workspace = resolve_workspace(workspace)
    embeddings_manager = EmbeddingsManager.from_config(AiConfig.from_env())
    _get_cache_manager()  # Ensure cache manager is registered before config loads
    embeddings_manager = EmbeddingsManager(AiConfig.from_env())

    embedding_results = embeddings_manager.query_embeddings(query_text, workspace or "default", top_k)

@@ -193,7 +189,7 @@ def ai_query(
    # Format results as expected by tests: {"query": ..., "results": [...]}
    formatted_results = []
    for chunk, score in embedding_results:
        formatted_results.append({"document_id": chunk.document_id, "section": chunk.section, "content": chunk.content, "score": score})
        formatted_results.append({"document_id": chunk.document_id, "section": chunk.section_heading, "content": chunk.content, "score": score})
    payload = {"query": query_text, "results": formatted_results}

    if json_output:
@@ -207,7 +203,7 @@ def ai_query(
            table.add_column("Snippet", style="white")
            for chunk, score in embedding_results:
                snippet = chunk.content[:120].replace("\n", " ")
                table.add_row(chunk.document_id, str(chunk.section or ""), snippet, f"{score:.3f}")
                table.add_row(chunk.document_id, str(chunk.section_heading or ""), snippet, f"{score:.3f}")
            console.print(table)
        else:
            console.print("[yellow]No embedding results found.[/yellow]")
@@ -339,6 +335,7 @@ def ai_graph(
        console.print("[red]Error: query is required (positional or --query).[/red]")
        raise typer.Exit(1)

    _get_cache_manager()  # Ensure cache manager is registered before config loads
    # Tests expect None as default, so don't convert to "default"
    result = query_graph(query_text, workspace=workspace)

@@ -529,7 +526,7 @@ def workspace_clear(
) -> None:
    """Clear all AI artifacts (embeddings, summaries, etc.) while preserving workspace members."""
    workspace = resolve_workspace(workspace)
    storage = EmbeddingsManager.from_config(AiConfig.from_env()).storage
    storage = EmbeddingsManager(AiConfig.from_env()).storage

    removed_count = storage.clear_workspace_artifacts(workspace)

@@ -712,7 +709,56 @@ def workspace_process(
            console.print(f"[yellow]No active members found in workspace '{normalize_workspace_name(workspace)}'[/yellow]")
        return

    # Create progress bar for tracking document processing
    # Create progress bars for each phase
    # Phase 1: Extract (Classify + Extract)
    # Phase 2: Embed
    # Phase 3: Graph
    phase_names = {
        1: "Extracting (Classify + Extract)",
        2: "Embedding",
        3: "Building Graph",
    }

    # Track phase progression
    current_phase = 0
    processed_in_phase: dict[int, set[str]] = {1: set(), 2: set(), 3: set()}

    def progress_callback(stage: PipelineStage, doc_id: str) -> None:
        nonlocal current_phase

        # Determine which phase this callback belongs to
        if stage in (PipelineStage.CLASSIFYING, PipelineStage.EXTRACTING):
            phase = 1
        elif stage in (PipelineStage.EMBEDDING,):
            phase = 2
        elif stage in (PipelineStage.GRAPHING, PipelineStage.COMPLETED):
            phase = 3
        else:
            phase = current_phase

        # Phase transition - start new progress bar
        if phase != current_phase:
            current_phase = phase
            # Update all progress bars for phase transition
            for p in (1, 2, 3):
                if p in progress.tasks:
                    task = progress.tasks[p]
                    if p < phase:
                        progress.update(task, completed=task.total, description=f"[green]{phase_names[p]} - Done")
                    elif p == phase:
                        progress.update(task, description=f"[cyan]{phase_names[p]}")

        # Update current phase progress
        if phase in progress.tasks and doc_id not in processed_in_phase[phase]:
            processed_in_phase[phase].add(doc_id)
            task = progress.tasks[phase]
            progress.update(
                task,
                advance=1,
                description=f"[cyan]{phase_names[phase]}: {doc_id}",
            )

    # Run pipeline with progress tracking
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}", justify="right"),
@@ -723,28 +769,17 @@ def workspace_process(
        console=console,
        refresh_per_second=10,
    ) as progress:
        # Track processed count for progress updates
        processed_count = [0]  # Use list to allow modification in callback

        # Create main task
        task = progress.add_task(
            f"[cyan]Processing workspace '{normalize_workspace_name(workspace)}'[/cyan]",
        # Add tasks for each phase (initially hidden)
        tasks = {}
        for phase in (1, 2, 3):
            tasks[phase] = progress.add_task(
                f"[dim]{phase_names[phase]}",
                total=len(document_ids),
                visible=False,
            )

        # Create progress callback
        def progress_callback(stage: PipelineStage, doc_id: str) -> None:
            """Advance the single workspace progress bar by one step.

            Invoked once per pipeline callback. Uses the enclosing scope's
            ``progress`` and ``task`` plus ``processed_count`` (one-element
            list mutated in place so the closure can update it) and
            ``document_ids``.
            """
            total = len(document_ids)
            # Human-readable stage label, e.g. "some_stage" -> "Some Stage".
            stage_name = stage.value.replace("_", " ").title()
            processed_count[0] += 1
            done = processed_count[0]
            remaining = total - done
            label = (
                f"[cyan]{done}/{total}[/cyan] "
                f"[cyan]Processing {doc_id}[/cyan] "
                f"[dim]{remaining} remaining[/dim] "
                f"[dim]- {stage_name}[/dim]"
            )
            progress.update(task, advance=1, description=label)
        # Show first phase
        progress.update(tasks[1], visible=True)

        # Process documents with progress tracking
        results = process_all(