Commit 759592b7 authored by Jan Reimes's avatar Jan Reimes
Browse files

Fix: Call list_workspace_members from workspaces module instead of AiStorage

The process_all function was incorrectly calling storage.list_workspace_members()
which doesn't exist on AiStorage. Now it correctly imports and calls the
standalone list_workspace_members function from the workspaces module.

Also updated test to mock the correct function.
parent 43846beb
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ from tdoc_crawler.ai.operations.classify import classify_document_files
from tdoc_crawler.ai.operations.embeddings import EmbeddingsManager, generate_embeddings
from tdoc_crawler.ai.operations.extract import extract_from_folder
from tdoc_crawler.ai.operations.summarize import summarize_document
from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
from tdoc_crawler.ai.operations.workspaces import list_workspace_members, normalize_workspace_name
from tdoc_crawler.ai.storage import AiStorage
from tdoc_crawler.utils.misc import utc_now

@@ -371,7 +371,7 @@ def process_all(
    # Get workspace members and build a lookup map
    members_map: dict[str, Any] = {}
    if normalized_workspace != "default":
        members = storage.list_workspace_members(normalized_workspace)
        members = list_workspace_members(normalized_workspace)
        member_ids = {m.source_item_id for m in members if m.is_active and m.source_kind == "tdoc"}
        members_map = {m.source_item_id: m for m in members if m.is_active and m.source_kind == "tdoc"}
        document_ids = [tid for tid in document_ids if tid in member_ids]
+4 −9
Original line number Diff line number Diff line
@@ -85,17 +85,12 @@ class TestRunPipeline:
        results = process_all(tdoc_ids, base_path)
        assert isinstance(results, dict)

    @patch("tdoc_crawler.ai.operations.pipeline.AiStorage")
    @patch("tdoc_crawler.ai.operations.pipeline.list_workspace_members")
    @patch("tdoc_crawler.ai.operations.pipeline.run_pipeline")
    def test_process_all_scopes_to_workspace_members(
        self, mock_run_pipeline: MagicMock, mock_ai_storage: MagicMock, mock_storage: MagicMock, test_data_dir: Path
    ) -> None:
    def test_process_all_scopes_to_workspace_members(self, mock_run_pipeline: MagicMock, mock_list_members: MagicMock, test_data_dir: Path) -> None:
        """Test process_all filters input by workspace members."""
        mock_ai_storage.return_value = mock_storage

        mock_storage.list_workspace_members.return_value = [
        mock_list_members.return_value = [
            WorkspaceMember(workspace_name="test_ws", source_item_id="S4-251003", source_path="/path", source_kind="tdoc", status="included"),
            # Note: "26260-j10" is NOT in the workspace
        ]

        mock_run_pipeline.return_value = ProcessingStatus(document_id="S4-251003", current_stage=PipelineStage.COMPLETED)
@@ -106,7 +101,7 @@ class TestRunPipeline:

        assert "S4-251003" in results
        assert "26260-j10" not in results
        mock_storage.list_workspace_members.assert_called_once_with("test_ws")
        mock_list_members.assert_called_once_with("test_ws")

    def test_progress_callback_invocation(self, mock_storage: MagicMock, test_data_dir: Path) -> None:
        """Test progress callback is invoked after each stage."""