Commit 821cf120 authored by Jan Reimes's avatar Jan Reimes
Browse files

cli: make ai query/graph positional and add cached CacheManager helper

parent 2e2e19b3
Loading
Loading
Loading
Loading
+23 −24
Original line number Diff line number Diff line
@@ -25,13 +25,13 @@ from tdoc_crawler.ai import (
    make_workspace_member,
    normalize_workspace_name,
    process_document,
    query_embeddings,
    query_graph,
    set_active_workspace,
    summarize_document,
)
from tdoc_crawler.ai.models import PipelineStage, SourceKind
from tdoc_crawler.ai.operations.pipeline import process_all
from tdoc_crawler.ai.operations.workspace_registry import WorkspaceRegistry
from tdoc_crawler.ai.operations.workspaces import (
    add_workspace_members,
    get_workspace,
@@ -57,18 +57,17 @@ from tdoc_crawler.utils.date_parser import parse_partial_date
ai_app = typer.Typer(help="AI document processing commands")
console = Console()


def _get_ai_dir() -> Path:
    """Get the .ai directory path."""
    manager = CacheManager().register()
    return manager.ai_cache_dir
# Global cache manager instance - lazily initialized
_default_cache_manager: CacheManager | None = None


def clear_active_workspace() -> None:
    """Clear the active workspace file (deprecated - use registry)."""
    active_file = _get_ai_dir() / "active_workspace"
    if active_file.exists():
        active_file.unlink()
def _get_cache_manager() -> CacheManager:
    """Get or create the default cache manager (avoids repeated register() calls)."""
    global _default_cache_manager
    if _default_cache_manager is None:
        _default_cache_manager = CacheManager(name="default")
        _default_cache_manager.register()
    return _default_cache_manager


def resolve_workspace(workspace: str | None) -> str:
@@ -138,14 +137,16 @@ def ai_convert(

@ai_app.command("query")
def ai_query(
    query: Annotated[str, typer.Option("--query", "-q", help="Semantic search query")],
    query: Annotated[str, typer.Argument(..., help="Semantic search query")],
    workspace: Annotated[str | None, typer.Option("--workspace", "-w", help="Workspace name")] = None,
    top_k: Annotated[int, typer.Option("--top-k", "-k", help="Number of embedding results to return")] = 5,
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Search TDocs using semantic embeddings and knowledge graph (RAG + GraphRAG)."""
    # Get embeddings manager directly from container
    embeddings_manager = AiServiceContainer.get_instance().get_embeddings_manager()
    # Tests expect None as default, so don't convert to "default"
    embedding_results = query_embeddings(query, top_k=top_k, workspace=workspace)
    embedding_results = embeddings_manager.query_embeddings(query, workspace or "default", top_k)
    graph_result = query_graph(query, workspace=workspace)

    # Format results as expected by tests: {"query": ..., "results": [...]}
@@ -265,7 +266,7 @@ def ai_status(

@ai_app.command("graph")
def ai_graph(
    query: Annotated[str, typer.Option("--query", "-q", help="Graph query string")],
    query: Annotated[str, typer.Argument(..., help="Graph query string")],
    workspace: Annotated[str | None, typer.Option("--workspace", "-w", help="Workspace name")] = None,
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
@@ -299,7 +300,7 @@ def workspace_create(
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Create a new workspace."""
    CacheManager().register()
    _get_cache_manager()

    registry = create_workspace(name, auto_build=auto_build)
    workspace = registry.get_workspace(name)
@@ -327,9 +328,7 @@ def workspace_list(
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """List all workspaces."""
    from tdoc_crawler.ai.operations.workspace_registry import WorkspaceRegistry

    CacheManager().register()
    _get_cache_manager()
    registry = WorkspaceRegistry.load()
    workspaces = registry.list_workspaces()

@@ -375,7 +374,7 @@ def workspace_info(
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Get detailed workspace information including member counts."""
    CacheManager().register()
    _get_cache_manager()

    workspace = get_workspace(name)

@@ -464,7 +463,7 @@ def workspace_clear(
) -> None:
    """Clear all AI artifacts (embeddings, summaries, etc.) while preserving workspace members."""
    workspace = resolve_workspace(workspace)
    CacheManager().register()
    _get_cache_manager()
    storage = AiServiceContainer.get_instance().get_storage()

    removed_count = storage.clear_workspace_artifacts(workspace)
@@ -504,7 +503,7 @@ def workspace_add_members(
    If no items are provided, queries the database using the provided filters.
    """
    workspace = resolve_workspace(workspace)
    manager = CacheManager().register()
    manager = _get_cache_manager()

    # Normalize kind to singular form (accept both 'tdoc' and 'tdocs')
    kind_normalized = kind.lower().rstrip("s")
@@ -590,7 +589,7 @@ def workspace_list_members(
) -> None:
    """List members of a workspace."""
    workspace = resolve_workspace(workspace)
    CacheManager().register()
    _get_cache_manager()

    members = list_workspace_members(workspace, include_inactive=include_inactive)

@@ -639,7 +638,7 @@ def workspace_process(
) -> None:
    """Process all active document members in a workspace through the AI pipeline."""
    workspace = resolve_workspace(workspace)
    manager = CacheManager().register()
    manager = _get_cache_manager()

    # Get workspace members
    members = list_workspace_members(workspace, include_inactive=False)
@@ -718,7 +717,7 @@ def workspace_delete(
    json_output: Annotated[bool, typer.Option("--json", help="Output as JSON")] = False,
) -> None:
    """Delete a workspace (default workspace cannot be deleted)."""
    CacheManager().register()
    _get_cache_manager()

    result = delete_workspace(name, preserve_artifacts=preserve_artifacts)