Commit 38d0f720 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(pipeline): refactor pipeline stages and improve processing flow

* Consolidate classify and extract stages into a single function.
* Introduce separate functions for embedding and graphing stages.
* Enhance error handling and logging throughout the pipeline.
* Allow reuse of embeddings manager across stages for efficiency.
* Update process_all to utilize new phase functions for clarity.
* Improve workspace member resolution and document folder handling.
* Refactor status retrieval and creation logic for better maintainability.
parent 27387dd6
Loading
Loading
Loading
Loading
+618 −372

File changed.

Preview size limit exceeded, changes collapsed.

+41 −37
Original line number Diff line number Diff line
@@ -4,13 +4,20 @@ from __future__ import annotations

import json
import logging
import platform
from datetime import UTC, datetime
from functools import cache
from pathlib import Path

import typer
from rich.progress import BarColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeRemainingColumn
from rich.progress import (
    BarColumn,
    Progress,
    SpinnerColumn,
    TaskProgressColumn,
    TextColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)
from rich.table import Table
from tdoc_ai import (
    checkout_spec_to_workspace,
@@ -31,7 +38,7 @@ from tdoc_ai import (
from tdoc_ai.config import AiConfig
from tdoc_ai.models import SourceKind
from tdoc_ai.operations.embeddings import EmbeddingsManager
from tdoc_ai.operations.pipeline import process_all
from tdoc_ai.operations.pipeline import process_all, process_embed_phase, process_extract_phase, process_graph_phase
from tdoc_ai.operations.workspace_registry import WorkspaceRegistry
from tdoc_ai.operations.workspaces import (
    add_workspace_members,
@@ -687,7 +694,6 @@ def workspace_process(
    """
    workspace = resolve_workspace(workspace)
    manager = _get_cache_manager()
    is_windows = platform.system() == "Windows"

    # Get workspace members
    members = list_workspace_members(workspace, include_inactive=False)
@@ -705,48 +711,46 @@ def workspace_process(
            console.print(f"[yellow]No active members found in workspace '{normalize_workspace_name(workspace)}'[/yellow]")
        return

    # Simple Windows output (no rich progress bar due to WindowsLegacyRenderer bugs)
    if is_windows:
        console.print(f"[cyan]Processing {len(document_ids)} document(s) in workspace '{normalize_workspace_name(workspace)}'...[/cyan]")
        console.print("[dim]Phase 1: Extracting (Classify + Extract)[/dim]")
        results = process_all(
            document_ids=document_ids,
            checkout_base=manager.root,
            new_only=new_only,
            force_rerun=force_rerun,
            workspace=workspace,
        )
        console.print(f"[green]OK: Processed {len(results)} documents[/green]")
    else:
        # Unix: Three separate progress bars, one per phase
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        TransferSpeedColumn(),
        TimeRemainingColumn(),
        console=console,
        refresh_per_second=10,
    ) as progress:
            # Phase 1: Extract
            extract_task = progress.add_task("[cyan]Phase 1: Extracting (Classify + Extract)", total=len(document_ids))
            results = process_all(
        extract_task = progress.add_task("[cyan]Phase 1: Classifying/Extracting", total=len(document_ids))
        extracted_ids, extract_results = process_extract_phase(
            document_ids=document_ids,
            checkout_base=manager.root,
            new_only=new_only,
            force_rerun=force_rerun,
            workspace=workspace,
            progress_callback=lambda _: progress.advance(extract_task),
        )
            progress.update(extract_task, completed=len(document_ids))

            # Phase 2: Embed
            embed_task = progress.add_task("[cyan]Phase 2: Embedding", total=len(document_ids))
            progress.update(embed_task, completed=len(document_ids))
        embed_task = progress.add_task("[cyan]Phase 2: Embedding", total=len(extracted_ids))
        embedded_ids, embed_results = process_embed_phase(
            document_ids=extracted_ids,
            checkout_base=manager.root,
            force_rerun=force_rerun,
            workspace=workspace,
            progress_callback=lambda _: progress.advance(embed_task),
        )

            # Phase 3: Graph
            graph_task = progress.add_task("[cyan]Phase 3: Building Graph", total=len(document_ids))
            progress.update(graph_task, completed=len(document_ids))
        graph_task = progress.add_task("[cyan]Phase 3: Building Graph", total=len(embedded_ids))
        _, graph_results = process_graph_phase(
            document_ids=embedded_ids,
            checkout_base=manager.root,
            workspace=workspace,
            progress_callback=lambda _: progress.advance(graph_task),
        )

        results = {**extract_results, **embed_results, **graph_results}

    console.print(f"[green]OK: Processed {len(results)} documents[/green]")
    if json_output:
        typer.echo(
            json.dumps(