Commit 983cf4a6 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(ai): add embedding backend option for AI processing

parent 2d24121d
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -48,6 +48,7 @@ from tdoc_crawler.cli.args import (
    ConvertDocumentArgument,
    ConvertOutputOption,
    EmbeddingTopKOption,
    EmbeddingBackendOption,
    EndDateOption,
    GraphQueryArgument,
    GraphQueryOption,
@@ -226,10 +227,12 @@ def ai_process(
    process_all_flag: ProcessAllOption = False,
    new_only: ProcessNewOnlyOption = False,
    force: ProcessForceOption = False,
    accelerate: EmbeddingBackendOption = "torch",
    json_output: JsonOutputOption = False,
) -> None:
    """Process a single document or all documents through the AI pipeline."""
    workspace = workspace or "default"
    config = AiConfig.from_env(embedding_backend=accelerate)

    if process_all_flag:
        # Process all documents in workspace
@@ -241,6 +244,7 @@ def ai_process(
            new_only=new_only,
            force_rerun=force,
            workspace=workspace,
            config=config,
        )
        if json_output:
            typer.echo(json.dumps(result))
@@ -250,7 +254,13 @@ def ai_process(
        # Process single document
        manager = _get_cache_manager()
        resolved_checkout = Path(checkout_path) if checkout_path else manager.checkout_dir / document_id
        result = process_document(document_id, workspace=workspace, checkout_path=resolved_checkout, force_rerun=force)
        result = process_document(
            document_id,
            workspace=workspace,
            checkout_path=resolved_checkout,
            force_rerun=force,
            config=config,
        )
        if json_output:
            typer.echo(json.dumps(result))
        else:
@@ -287,6 +297,8 @@ def ai_status(
                console.print(f"  {status} {stage}")
    else:
        # Get status for all documents in workspace
        # Ensure cache manager is registered before calling get_status
        _get_cache_manager()
        # get_status without document_id returns a list
        statuses = get_status(workspace=workspace)

+10 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Annotated
from typing import Annotated, Literal

import typer

@@ -170,6 +170,15 @@ WorkspaceNameOption = Annotated[str | None, typer.Option("--workspace", "-w", he
EmbeddingTopKOption = Annotated[int, typer.Option("--top-k", "-k", help="Number of embedding results to return")]

ProcessTDocIdOption = Annotated[str | None, typer.Option("--tdoc-id", "-t", help="TDoc ID to process")]
EmbeddingBackendOption = Annotated[
    Literal["torch", "onnx", "openvino"],
    typer.Option(
        "--accelerate",
        "-a",
        help="Embedding backend (torch, onnx, openvino)",
        envvar="TDC_AI_EMBEDDING_BACKEND",
    ),
]
CheckoutPathOption = Annotated[str | None, typer.Option("--checkout-path", help="Path to checkout document")]
CheckoutBaseOption = Annotated[str | None, typer.Option("--checkout-base", help="Base path for checkout")]
ProcessAllOption = Annotated[bool, typer.Option("--all", help="Process all documents in workspace")]