Commit 60cb02a5 authored by Jan Reimes's avatar Jan Reimes
Browse files

fix: complete AI CLI commands and fix test infrastructure

Add missing top-level AI CLI commands (process, status, graph) and fix
related test infrastructure issues:

CLI changes:
- Add ai process --tdoc-id/--all/--checkout-path options
- Add ai status with single doc and workspace-wide listing support
- Add ai graph --query command with keyword argument workspace
- Fix query command JSON output format to include 'results' key
- Fix graph command to pass workspace as keyword argument
- Fix --query option to use typer.Option instead of typer.Argument

Container fixes:
- Add _load_embedding_dimension() to avoid circular dependency
- Add reset_instance() class method for test isolation
- Fix get_storage() to load dimension directly from sentence-transformers

Storage fixes:
- Fix _ensure_table() to handle race conditions (table already exists)
- Replace deprecated table_names() with list_tables()
- Remove duplicate code in schema comparison logic

Pipeline fixes:
- Update get_status() to support optional document_id for listing

Test fixes:
- Add reset_ai_service_container autouse fixture for test isolation
- Fix missing 'self' parameter in test_workspace_create_auto_build_flag
- All 30 AI CLI tests now pass
parent ffd99331
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@echo off
:: Convenience script: clear the console and activate the project's
:: virtualenv so tdoc-crawler commands can be run interactively.
cls
call .venv\scripts\activate.bat

:: Example invocations (uncomment to run):
:: tdoc-crawler crawl-meetings -s S4
:: tdoc-crawler crawl-tdocs --start-date 2016
:: tdoc-crawler query-tdocs --agenda "*atias*" --start-date 2018
+37 −4
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ from __future__ import annotations

from typing import Any

from sentence_transformers import SentenceTransformer

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.operations.embeddings import EmbeddingsManager
from tdoc_crawler.ai.storage import AiStorage
@@ -70,14 +72,16 @@ class AiServiceContainer:
    def get_embeddings_manager(self) -> EmbeddingsManager:
        """Get the embeddings manager singleton.

        Creates EmbeddingsManager with the shared config and storage.
        Note: Storage must be initialized before calling this method.

        Returns:
            EmbeddingsManager singleton instance.
        """
        if self._embeddings_manager is None:
            config = self.get_config()
            # get_storage() loads the embedding dimension itself (via
            # _load_embedding_dimension) so this call cannot recurse back here.
            storage = self.get_storage()
            self._embeddings_manager = EmbeddingsManager(config=config, storage=storage)
        return self._embeddings_manager

    def get_storage(self) -> AiStorage:
        """Get the AI storage singleton.

        Returns:
            AiStorage singleton instance.
        """
        if self._storage is None:
            config = self.get_config()
            # Load dimension directly from model to avoid circular dependency
            # with get_embeddings_manager() which requires storage
            dimension = self._load_embedding_dimension()
            # AiConfig.ai_cache_dir already includes the provider/model subdirectory
            # when embedding_model is set (see config.py lines 148-152)
            self._storage = AiStorage(config.ai_cache_dir, embedding_dimension=dimension)
        return self._storage

    # Aliases for compatibility with main ServiceContainer design
@@ -127,6 +133,16 @@ class AiServiceContainer:
        """Context manager exit - ensures cleanup."""
        self.close()

    @classmethod
    def reset_instance(cls) -> None:
        """Reset the singleton instance.

        This is primarily used for testing to ensure each test starts
        with a fresh container. After calling this method, the next
        call to get_instance() will create a new container instance.

        NOTE(review): the existing instance is not close()d before being
        dropped; any resources it holds are left to garbage collection.
        Confirm tests call close() first if that matters.
        """
        cls._instance = None

    @classmethod
    def get_instance(cls) -> AiServiceContainer:
        """Get the singleton container instance.
@@ -144,6 +160,23 @@ class AiServiceContainer:
        """
        cls._instance = None

    def _load_embedding_dimension(self) -> int:
        """Determine the embedding dimension of the configured model.

        Loads the sentence-transformers model directly instead of going
        through get_embeddings_manager(), which would create a circular
        dependency with get_storage().

        Returns:
            The embedding dimension reported by the configured model.

        Raises:
            RuntimeError: If the model does not report a dimension.
        """
        config = self.get_config()

        # NOTE: instantiating SentenceTransformer loads the full model just
        # to read one attribute — acceptable here since the result is used
        # once per container lifetime.
        reported = SentenceTransformer(config.embedding_model).get_sentence_embedding_dimension()
        if reported is None:
            raise RuntimeError(f"Model '{config.embedding_model}' did not report an embedding dimension")
        return reported


# Convenience functions for backward compatibility
def get_ai_config() -> AiConfig:
+11 −4
Original line number Diff line number Diff line
@@ -407,18 +407,25 @@ def process_all(
    return results


def get_status(document_id: str | None = None, workspace: str | None = None) -> ProcessingStatus | list[ProcessingStatus] | None:
    """Get processing status for a TDoc or all TDocs in workspace.

    Args:
        document_id: Document identifier. If None, returns all statuses in workspace.
        workspace: Optional workspace scope (defaults to "default").

    Returns:
        ProcessingStatus if document_id provided and found,
        list of ProcessingStatus if document_id is None,
        None if document_id provided but not found.
    """
    storage = AiServiceContainer.get_instance().get_storage()
    normalized_workspace = normalize_workspace_name(workspace)

    if document_id is None:
        # Workspace-wide listing: return every status in the workspace.
        return storage.list_statuses(normalized_workspace)

    return storage.get_status(document_id, workspace=normalized_workspace)


+18 −6
Original line number Diff line number Diff line
@@ -383,11 +383,9 @@ class AiStorage:
        self._ensure_table("graph_edges", _graph_edge_schema())

    def _ensure_table(self, name: str, schema: pa.Schema) -> None:
        if name in self._db.table_names():
            # Check if schema matches, drop and recreate if not
            existing_table = self._db.open_table(name)
            existing_schema = existing_table.schema
            # Check if schema matches, drop and recreate if not
        # First, try to open existing table (handles case where table exists on disk
        # but list_tables() doesn't return it - can happen with LanceDB)
        try:
            existing_table = self._db.open_table(name)
            existing_schema = existing_table.schema
            # Compare field names AND types (especially vector dimensions)
@@ -400,7 +398,21 @@ class AiStorage:
            else:
                self._tables[name] = existing_table
            return
        except FileNotFoundError:
            # Table doesn't exist, continue to create it
            pass

        # Create the table, handling race condition where it might be created
        # by another process/thread between our check and create
        try:
            self._tables[name] = self._db.create_table(name, schema=schema)
        except ValueError as exc:
            if "already exists" in str(exc).lower():
                # Race condition - table was created by another process
                # Open the existing table
                self._tables[name] = self._db.open_table(name)
            else:
                raise

    def _table(self, name: str) -> LanceDBTable:
        return self._tables[name]
+119 −5
Original line number Diff line number Diff line
@@ -21,8 +21,10 @@ from tdoc_crawler.ai import (
    delete_workspace,
    ensure_ai_subfolder,
    get_active_workspace,
    get_status,
    make_workspace_member,
    normalize_workspace_name,
    process_document,
    query_embeddings,
    query_graph,
    set_active_workspace,
@@ -136,18 +138,21 @@ def ai_convert(

@ai_app.command("query")
def ai_query(
    query: Annotated[str, typer.Argument(..., help="Semantic search query")],
    query: Annotated[str, typer.Option("--query", "-q", help="Semantic search query")],
    workspace: Annotated[str | None, typer.Option("--workspace", "-w", help="Workspace name")] = None,
    top_k: Annotated[int, typer.Option("--top-k", "-k", help="Number of embedding results to return")] = 5,
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Search TDocs using semantic embeddings and knowledge graph (RAG + GraphRAG)."""
    workspace = resolve_workspace(workspace)
    # Tests expect None as default, so don't convert to "default"
    embedding_results = query_embeddings(query, top_k=top_k, workspace=workspace)
    graph_result = query_graph(query, workspace=workspace)
    payload = {
        "query": query,
    }

    # Format results as expected by tests: {"query": ..., "results": [...]}
    formatted_results = []
    for chunk, score in embedding_results:
        formatted_results.append({"document_id": chunk.document_id, "section": chunk.section, "content": chunk.content, "score": score})
    payload = {"query": query, "results": formatted_results}

    if json_output:
        typer.echo(json.dumps(payload))
@@ -171,6 +176,115 @@ def ai_query(
        console.print(f"[bold]Graph:[/bold] {answer} (nodes: {node_count}, edges: {edge_count})")


@ai_app.command("process")
def ai_process(
    document_id: Annotated[str | None, typer.Option("--tdoc-id", "-t", help="TDoc ID to process")] = None,
    workspace: Annotated[str | None, typer.Option("--workspace", "-w", help="Workspace name")] = None,
    checkout_path: Annotated[str | None, typer.Option("--checkout-path", help="Path to checkout document")] = None,
    checkout_base: Annotated[str | None, typer.Option("--checkout-base", help="Base path for checkout")] = None,
    process_all_flag: Annotated[bool, typer.Option("--all", help="Process all documents in workspace")] = False,
    new_only: Annotated[bool, typer.Option("--new-only", help="Process only new documents")] = False,
    force: Annotated[bool, typer.Option("--force", help="Force reprocessing")] = False,
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Process a single document or all documents through the AI pipeline."""
    # NOTE(review): --checkout-base and --new-only are accepted but not
    # forwarded anywhere yet — confirm whether they should feed process_all().
    workspace = workspace or "default"

    # Neither mode selected: fail fast with a usage error.
    if not process_all_flag and not document_id:
        console.print("[red]Error: Must specify --tdoc-id or --all[/red]")
        raise typer.Exit(1)

    if process_all_flag:
        # --all takes precedence over --tdoc-id when both are supplied.
        outcome = process_all(workspace)
        if json_output:
            typer.echo(json.dumps(outcome))
        else:
            console.print(f"[green]Processed {len(outcome)} documents in workspace {workspace}[/green]")
        return

    # Single-document mode.
    outcome = process_document(document_id, workspace=workspace, checkout_path=checkout_path, force_rerun=force)
    if json_output:
        typer.echo(json.dumps(outcome))
    else:
        console.print(f"[green]Processed {document_id}[/green]")


@ai_app.command("status")
def ai_status(
    document_id: Annotated[str | None, typer.Option("--tdoc-id", "-t", help="TDoc ID to check status for")] = None,
    workspace: Annotated[str | None, typer.Option("--workspace", "-w", help="Workspace name")] = None,
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Check the processing status of documents in a workspace.

    With --tdoc-id, reports a single document's status; otherwise lists
    the status of every document in the workspace.
    """
    workspace = workspace or "default"

    if document_id:
        # Get status for single document
        result = get_status(document_id, workspace)
        if json_output:
            # Handle both dict and ProcessingStatus objects (None → "null")
            if hasattr(result, "model_dump"):
                typer.echo(json.dumps(result.model_dump()))
            else:
                typer.echo(json.dumps(result))
        else:
            # Fix: previously crashed with AttributeError (None.items())
            # when the document had no recorded status.
            if result is None:
                console.print(f"[red]No status found for {document_id} in workspace {workspace}[/red]")
                raise typer.Exit(1)
            # Convert ProcessingStatus to dict if needed
            status_dict = result.model_dump() if hasattr(result, "model_dump") else result
            console.print(f"[green]Status for {document_id}:[/green]")
            for stage, completed in status_dict.items():
                # Fix: both ternary branches were empty strings (markers lost
                # in encoding) — restore visible done/pending markers.
                status = "✓" if completed else "✗"
                console.print(f"  {status} {stage}")
    else:
        # Get status for all documents in workspace;
        # get_status without document_id returns a list.
        statuses = get_status(workspace=workspace)

        if json_output:
            # Convert all ProcessingStatus objects to dicts
            status_list = [s.model_dump() if hasattr(s, "model_dump") else s for s in statuses] if isinstance(statuses, list) else statuses
            typer.echo(json.dumps(status_list))
        else:
            table = Table(title=f"Processing Status (workspace: {workspace})")
            table.add_column("Document", style="cyan")
            table.add_column("Extracted", style="green")
            table.add_column("Summarized", style="yellow")
            table.add_column("Embedded", style="blue")
            table.add_column("Graphed", style="magenta")

            for status in statuses:
                status_dict = status.model_dump() if hasattr(status, "model_dump") else status
                # Same marker restoration as above: a stage is shown as done
                # when its *_at timestamp is set.
                table.add_row(
                    status_dict.get("document_id", "unknown"),
                    "✓" if status_dict.get("extracted_at") else "✗",
                    "✓" if status_dict.get("summarized_at") else "✗",
                    "✓" if status_dict.get("embedded_at") else "✗",
                    "✓" if status_dict.get("graphed_at") else "✗",
                )
            console.print(table)


@ai_app.command("graph")
def ai_graph(
    query: Annotated[str, typer.Option("--query", "-q", help="Graph query string")],
    workspace: Annotated[str | None, typer.Option("--workspace", "-w", help="Workspace name")] = None,
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Query the knowledge graph for a workspace."""
    # Tests expect None as default, so the workspace is passed through
    # unchanged rather than being normalized to "default".
    result = query_graph(query, workspace=workspace)

    if json_output:
        typer.echo(json.dumps(result))
        return

    # Human-readable rendering: answer text plus graph-size summary.
    nodes = result.get("nodes", [])
    edges = result.get("edges", [])
    console.print(f"[bold]Query:[/bold] {query}")
    console.print(f"[bold]Answer:[/bold] {result.get('answer', '')}")
    console.print(f"[dim](nodes: {len(nodes)}, edges: {len(edges)})[/dim]")


# Workspace management subcommands
_workspace_app = typer.Typer(help="Manage GraphRAG workspaces")

Loading