Commit 26717a61 authored by Jan Reimes
Browse files

♻️ refactor(cli): improve workspace member processing with progress tracking

parent a13789e1
Loading
Loading
Loading
Loading
+38 −12
Original line number Diff line number Diff line
@@ -9,12 +9,13 @@ from __future__ import annotations
import asyncio
import json
import shutil
from collections.abc import Callable
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

import typer
from rich.progress import Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn
from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
from rich.table import Table

from tdoc_crawler.config import CacheManager, resolve_cache_manager
@@ -245,7 +246,21 @@ def _try_build_tdoc_metadata(source_item_id: str) -> RAGMetadata | None:
        return None


async def _process_workspace_members(workspace: str, members: list[Any]) -> list[dict[str, Any]]:
async def _process_workspace_members(
    workspace: str,
    members: list[Any],
    on_progress: Callable[[int, str], None] | None = None,
) -> list[dict[str, Any]]:
    """Process workspace members with optional progress callback.

    Args:
        workspace: Workspace name
        members: List of workspace members to process
        on_progress: Optional callback(completed_count, source_item_id) called after each member

    Returns:
        List of processing results
    """
    processor = TDocProcessor(LightRAGConfig.from_env())
    results: list[dict[str, Any]] = []

@@ -259,8 +274,10 @@ async def _process_workspace_members(workspace: str, members: list[Any]) -> list
                        "source_item_id": member.source_item_id,
                        "status": "skipped",
                        "reason": "path or supported file not found",
                    }
                    },
                )
                if on_progress:
                    on_progress(len(results), member.source_item_id)
                continue

            metadata = _try_build_tdoc_metadata(member.source_item_id)
@@ -273,8 +290,10 @@ async def _process_workspace_members(workspace: str, members: list[Any]) -> list
                    "chars_extracted": process_result.chars_extracted,
                    "reason": process_result.reason,
                    "error": process_result.error,
                }
                },
            )
            if on_progress:
                on_progress(len(results), member.source_item_id)
    finally:
        await processor.rag.stop()

@@ -358,8 +377,8 @@ def workspace_list(
                        "created_at": entry.created_at,
                    }
                    for entry in workspaces
                ]
            )
                ],
            ),
        )
        return

@@ -519,8 +538,8 @@ def workspace_list_members(
                        "added_at": entry.added_at,
                    }
                    for entry in members
                ]
            )
                ],
            ),
        )
        return

@@ -556,13 +575,20 @@ def workspace_process(
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("[cyan]Processing workspace members", total=len(members))
        results = asyncio.run(_process_workspace_members(workspace_name, members))
        progress.update(task, completed=len(results))
        task = progress.add_task("[cyan]Processing...", total=len(members))
        completed = 0

        def on_progress(count: int, source_item_id: str) -> None:
            """Mirror per-member progress onto the Rich progress task.

            Args:
                count: Value stored in ``completed`` and passed to the task as
                    its completed amount.
                source_item_id: Identifier shown as the task description.
            """
            # Rebind the enclosing scope's ``completed`` rather than creating a local.
            nonlocal completed
            completed = count
            progress.update(task, completed=completed, description=f"[cyan]{source_item_id}")

        results = asyncio.run(_process_workspace_members(workspace_name, members, on_progress=on_progress))
        progress.update(task, completed=len(results), description="[cyan]Processing complete")

    success_count = sum(1 for row in results if row["status"] == "success")
    skipped_count = sum(1 for row in results if row["status"] == "skipped")