Commit 37cfe8bd authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(workspaces): add auto_build flag to workspace creation and management

* Introduce `auto_build` parameter in `create_workspace` and `ensure_workspace` functions.
* Update `AiStorage` to handle `auto_build` flag during workspace creation.
* Modify CLI commands to support `--auto-build` option for workspace creation.
* Enhance tests to verify the persistence of the `auto_build` flag in workspaces.
* Refactor related functions and update documentation accordingly.

feat(summarize): implement summarize_tdoc function for TDoc summarization

* Add `summarize_tdoc` function to generate concise summaries and extract keywords.
* Introduce prompts for summary generation and keyword extraction.
* Implement error handling for remote metadata fetching.
* Update CLI to include `ai summarize` command for summarizing TDocs.

feat(graph): enhance query_graph function to return structured results

* Modify `query_graph` to return a dictionary with results and metadata.
* Integrate workspace normalization in graph queries.
* Update related tests to validate new functionality.

fix(embeddings): enforce required workspace parameter in query_embeddings

* Change `query_embeddings` to require a workspace parameter.
* Update function documentation to reflect changes.

test: add comprehensive tests for new features and CLI commands

* Implement tests for workspace creation with `auto_build` flag.
* Add tests for the new `ai summarize` command and its output formats.
* Update existing tests to match the changed function signatures and return types.
parent f90b0975
Loading
Loading
Loading
Loading
+18 −171
Original line number Diff line number Diff line
@@ -3,9 +3,7 @@
from __future__ import annotations

from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Any

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import (
@@ -17,11 +15,13 @@ from tdoc_crawler.ai.models import (
    PipelineStage,
    ProcessingStatus,
)
from tdoc_crawler.ai.operations.convert import convert_tdoc
from tdoc_crawler.ai.operations.embeddings import query_embeddings as _query_embeddings
from tdoc_crawler.ai.operations.graph import query_graph as _query_graph
from tdoc_crawler.ai.operations.pipeline import get_status as _pipeline_get_status_impl
from tdoc_crawler.ai.operations.pipeline import process_all as _pipeline_process_all_impl
from tdoc_crawler.ai.operations.pipeline import process_tdoc as _pipeline_process_tdoc_impl
from tdoc_crawler.ai.operations.summarize import SummarizeResult, summarize_tdoc
from tdoc_crawler.ai.operations.workspaces import (
    DEFAULT_WORKSPACE,
    add_workspace_members,
@@ -40,9 +40,20 @@ from tdoc_crawler.config import CacheManager


def _pipeline_get_status(tdoc_id: str, workspace: str) -> ProcessingStatus | None:
    """Thin facade over the pipeline status lookup implementation."""
    status = _pipeline_get_status_impl(tdoc_id, workspace=workspace)
    return status


def query_embeddings(query: str, workspace: str, top_k: int = 10) -> list:
    """Semantic embedding search scoped to *workspace*.

    Thin facade delegating to the embeddings operations module.
    """
    return _query_embeddings(query, workspace=workspace, top_k=top_k)


def query_graph(query: str, workspace: str, top_k: int = 10) -> list:
    """Knowledge-graph search scoped to *workspace*.

    Thin facade delegating to the graph operations module.
    """
    return _query_graph(query, workspace=workspace, top_k=top_k)


def _pipeline_process_all(
    tdoc_ids: list[str],
    checkout_base: Path,
@@ -75,175 +86,11 @@ def _pipeline_process_tdoc(
    )


def process_tdoc(
    tdoc_id: str,
    config: AiConfig | str | Path | None = None,
    stages: list[PipelineStage] | None = None,
    checkout_path: str | Path | None = None,
    force_rerun: bool = False,
    workspace: str | None = None,
) -> ProcessingStatus:
    """Process a single TDoc through the AI pipeline.

    ``stages`` and ``config`` are accepted for interface compatibility but are
    currently ignored by this facade. For backward compatibility, a str/Path
    passed as ``config`` is treated as ``checkout_path`` when the latter is
    omitted.

    Args:
        tdoc_id: TDoc identifier (e.g., "SP-123456").
        config: Optional AI configuration (currently unused).
        stages: Optional subset of pipeline stages (currently unused).
        checkout_path: Optional path to TDoc checkout folder.
        force_rerun: If True, skip resume logic.
        workspace: Optional workspace name. Omitted/blank resolves to default.

    Returns:
        Updated ProcessingStatus after pipeline execution.

    Raises:
        ValueError: If no checkout path is given and no status exists yet.
    """
    del stages  # accepted for compatibility, not forwarded
    ws = normalize_workspace_name(workspace)

    # Legacy call shape: process_tdoc("SP-123456", "/path/to/checkout")
    if checkout_path is None and isinstance(config, (str, Path)):
        config, checkout_path = None, config
    del config  # configuration is currently unused by this facade

    if checkout_path is None:
        # No checkout folder: fall back to reporting the stored status.
        status = _pipeline_get_status(tdoc_id, workspace=ws)
        if status is None:
            msg = f"No status found for {tdoc_id}. Provide checkout_path to process the TDoc."
            raise ValueError(msg)
        return status

    return _pipeline_process_tdoc(
        tdoc_id,
        Path(checkout_path),
        force_rerun=force_rerun,
        workspace=ws,
    )


def process_all(
    config: AiConfig | None = None,
    new_only: bool = False,
    stages: list[PipelineStage] | None = None,
    progress_callback: Callable[[str, PipelineStage], None] | None = None,
    tdoc_ids: list[str] | None = None,
    checkout_base: str | Path | None = None,
    force_rerun: bool = False,
    workspace: str | None = None,
) -> list[ProcessingStatus]:
    """Batch process TDocs through the AI pipeline.

    ``config`` and ``stages`` are accepted for interface compatibility but are
    currently ignored by this facade. When either ``tdoc_ids`` or
    ``checkout_base`` is omitted the function is a no-op and returns [].

    Args:
        config: Optional AI configuration (currently unused).
        new_only: If True, only statuses not yet COMPLETED are returned.
        stages: Optional subset of pipeline stages (currently unused).
        progress_callback: Optional callback invoked as (tdoc_id, stage).
        tdoc_ids: Optional list of TDoc identifiers (backward-compatible mode).
        checkout_base: Optional base path containing TDoc folders.
        force_rerun: If True, skip resume logic.
        workspace: Optional workspace name. Omitted/blank resolves to default.

    Returns:
        List of ProcessingStatus values.
    """
    _ = config
    _ = stages
    normalized_workspace = normalize_workspace_name(workspace)

    if tdoc_ids is None or checkout_base is None:
        return []

    # The underlying implementation takes (stage, tdoc_id); the public
    # callback takes (tdoc_id, stage) — adapt the argument order here.
    mapped_progress_callback = (lambda stage, current_tdoc_id: progress_callback(current_tdoc_id, stage)) if progress_callback else None
    try:
        results = _pipeline_process_all(
            tdoc_ids,
            Path(checkout_base),
            new_only=new_only,
            force_rerun=force_rerun,
            progress_callback=mapped_progress_callback,
            workspace=normalized_workspace,
        )
    except TypeError:
        # Compatibility shim for older implementations that lack the
        # new_only/progress_callback keywords.
        # NOTE(review): this also swallows any TypeError raised *inside* the
        # pipeline (including from the callback above) and silently retries
        # without new_only filtering or progress reporting — confirm this is
        # intentional.
        results = _pipeline_process_all(
            tdoc_ids,
            Path(checkout_base),
            new_only=False,
            force_rerun=force_rerun,
            progress_callback=None,
            workspace=normalized_workspace,
        )
    # The implementation returns a mapping keyed by TDoc id; callers get a list.
    statuses = list(results.values())
    if new_only:
        return [status for status in statuses if status.current_stage != PipelineStage.COMPLETED]
    return statuses


def get_status(
    tdoc_id: str | None = None,
    workspace: str | None = None,
) -> ProcessingStatus | list[ProcessingStatus] | None:
    """Get processing status for one TDoc, or all statuses in a workspace.

    Args:
        tdoc_id: Optional TDoc identifier. If omitted, returns all statuses.
        workspace: Optional workspace name. Omitted/blank resolves to default.

    Returns:
        ProcessingStatus, list of statuses, or None if not found.
    """
    ws = normalize_workspace_name(workspace)
    if tdoc_id is not None:
        return _pipeline_get_status(tdoc_id, workspace=ws)

    # No TDoc given: open the storage directly and list everything in scope.
    manager = CacheManager().register()
    storage = AiStorage(manager.root / ".ai" / "lancedb")
    ensure_default_workspace(storage)
    return storage.list_statuses(workspace=ws)


def query_embeddings(
    query: str,
    top_k: int = 5,
    workspace: str | None = None,
) -> list[tuple[DocumentChunk, float]]:
    """Semantic search over embedded document chunks.

    Args:
        query: Natural language query text.
        top_k: Number of top results to return.
        workspace: Optional workspace scope (defaults to "default").

    Returns:
        List of (chunk, similarity_score) tuples.
    """
    # The implementation's positional order is (query, workspace, top_k), so
    # passing top_k positionally would bind it to workspace and collide with
    # the workspace keyword — pass everything by keyword instead. Also resolve
    # an omitted/blank workspace to the default, since the implementation
    # requires an explicit workspace.
    return _query_embeddings(
        query,
        workspace=normalize_workspace_name(workspace),
        top_k=top_k,
    )


def query_graph(
    query: str,
    temporal_range: tuple[datetime, datetime] | None = None,
    node_types: list[str] | None = None,
    workspace: str | None = None,
) -> dict[str, Any]:
    """Query the temporal knowledge graph.

    Args:
        query: Natural language query about relationships/evolution.
        temporal_range: Optional (start, end) datetime filter, forwarded to
            the underlying query as its date range.
        node_types: Optional filter for node types.
            NOTE(review): not forwarded — the implementation expects
            GraphNodeType values, not strings; confirm the intended mapping.
        workspace: Optional workspace scope. Omitted/blank resolves to default.

    Returns:
        Dict with keys "nodes", "edges", and "answer".
    """
    # The implementation requires a workspace argument; resolve the default
    # when callers omit it.
    raw = _query_graph(
        query,
        normalize_workspace_name(workspace),
        top_k=10,
        date_range=temporal_range,
    )
    # Accept both return shapes: a bare list of results, or the newer
    # {"results": [...]} mapping.
    results = raw.get("results", []) if isinstance(raw, dict) else raw

    # Convert to dict format
    return {
        "nodes": [r.node.model_dump() for r in results],
        "edges": [e.model_dump() for r in results for e in r.connected_edges],
        "answer": f"Found {len(results)} related documents",
    }


__all__ = [
    "DEFAULT_WORKSPACE",
    "AiConfig",
    "AiStorage",
    "CacheManager",
    "DocumentChunk",
    "DocumentClassification",
    "DocumentSummary",
@@ -251,19 +98,19 @@ __all__ = [
    "GraphNode",
    "PipelineStage",
    "ProcessingStatus",
    "SummarizeResult",
    "add_workspace_members",
    "convert_tdoc",
    "create_workspace",
    "delete_workspace",
    "ensure_default_workspace",
    "get_status",
    "get_workspace",
    "is_default_workspace",
    "list_workspaces",
    "make_workspace_member",
    "normalize_workspace_name",
    "process_all",
    "process_tdoc",
    "query_embeddings",
    "query_graph",
    "resolve_workspace",
    "summarize_tdoc",
]
+44 −0
Original line number Diff line number Diff line
@@ -2,10 +2,12 @@

from __future__ import annotations

import json
from datetime import datetime
from enum import StrEnum
from typing import Any

import yaml
from pydantic import BaseModel, Field, field_validator, model_validator

from tdoc_crawler.utils.misc import utc_now
@@ -98,6 +100,7 @@ class Workspace(BaseModel):
    created_at: datetime = Field(default_factory=utc_now, description="Creation timestamp")
    updated_at: datetime = Field(default_factory=utc_now, description="Last update timestamp")
    is_default: bool = Field(False, description="Whether this workspace is the default workspace")
    auto_build: bool | None = Field(default=False, description="Whether to auto-build embeddings on TDoc add")
    status: WorkspaceStatus = Field(WorkspaceStatus.ACTIVE, description="Workspace lifecycle status")

    @field_validator("workspace_name")
@@ -307,6 +310,46 @@ class GraphEdge(BaseModel):
        return self


class SummarizeResult(BaseModel):
    """Result of TDoc summarization operation."""

    summary: str = Field(..., description="Generated summary text")
    keywords: list[str] = Field(default_factory=list, description="Extracted keywords")
    metadata: dict = Field(
        default_factory=dict,
        description="Metadata including meeting_id, source, word_count",
    )
    word_count: int = Field(..., ge=0, description="Actual word count of summary")

    @field_validator("keywords", mode="before")
    @classmethod
    def _normalize_keywords(cls, value: list[str] | None) -> list[str]:
        """Coerce keywords to stripped, non-empty strings; None becomes []."""
        if value is None:
            return []
        normalized: list[str] = []
        for keyword in value:
            text = str(keyword).strip()
            if text:
                normalized.append(text)
        return normalized

    def to_markdown(self) -> str:
        """Render the result as a markdown document."""
        parts: list[str] = [f"## Summary\n\n{self.summary}\n"]
        if self.keywords:
            parts.append(f"## Keywords\n\n{', '.join(self.keywords)}\n")
        if self.metadata:
            parts.append("## Metadata\n")
            parts.extend(f"- **{key}**: {value}\n" for key, value in self.metadata.items())
        parts.append(f"\n*Word count: {self.word_count}*\n")
        return "".join(parts)

    def to_json(self) -> str:
        """Render the result as a JSON string."""
        return json.dumps(self.model_dump(), indent=2)

    def to_yaml(self) -> str:
        """Render the result as a YAML string."""
        return yaml.dump(self.model_dump(), default_flow_style=False)


class GraphQueryResult(BaseModel):
    """Knowledge graph query result."""

@@ -333,6 +376,7 @@ __all__ = [
    "ProcessingStatus",
    "QueryResult",
    "SourceKind",
    "SummarizeResult",
    "TDocNotFoundError",
    "Workspace",
    "WorkspaceMember",
+137 −0
Original line number Diff line number Diff line
"""TDoc to Markdown conversion operations."""

from __future__ import annotations

from pathlib import Path

from tdoc_crawler.config import resolve_cache_manager
from tdoc_crawler.database.meetings import MeetingDatabase
from tdoc_crawler.logging import get_logger
from tdoc_crawler.tdocs.models import TDocMetadata
from tdoc_crawler.tdocs.sources.whatthespec import resolve_via_whatthespec

logger = get_logger(__name__)


def _get_meeting_info(meeting_id: int) -> str | None:
    """Best-effort lookup of a meeting's short name in the local database.

    Args:
        meeting_id: Meeting identifier.

    Returns:
        Meeting short name if found, None otherwise (including when the
        database cannot be read).
    """
    try:
        manager = resolve_cache_manager()
        with MeetingDatabase(manager.db_file) as db:
            # NOTE(review): reaches into the private _table_rows helper —
            # confirm whether MeetingDatabase offers a public lookup instead.
            rows = db._table_rows("meetings")
            match = next((row for row in rows if row.meeting_id == meeting_id), None)
            if match is not None:
                return match.short_name
    except Exception as exc:
        # Best-effort: a missing meeting name only degrades the output.
        logger.debug(f"Failed to get meeting info for {meeting_id}: {exc}")
    return None


def _format_markdown(metadata: TDocMetadata) -> str:
    """Render TDoc metadata as a markdown document.

    Args:
        metadata: TDoc metadata to format.

    Returns:
        Markdown formatted string.
    """
    # Title and TDoc ID are always present.
    parts: list[str] = [
        f"# {metadata.title}\n",
        f"**TDoc ID:** {metadata.tdoc_id}\n",
    ]

    # Prefer the human-readable meeting name when the database knows it.
    meeting_name = _get_meeting_info(metadata.meeting_id)
    if meeting_name:
        parts.append(f"**Meeting:** {meeting_name}\n")
    else:
        parts.append(f"**Meeting ID:** {metadata.meeting_id}\n")

    if metadata.source:
        parts.append(f"**Source:** {metadata.source}\n")
    if metadata.contact:
        parts.append(f"**Contact:** {metadata.contact}\n")

    # Type and purpose are emitted unconditionally.
    parts.append(f"**Type:** {metadata.tdoc_type}\n")
    parts.append(f"**For:** {metadata.for_purpose}\n")

    if metadata.agenda_item_text:
        parts.append(f"**Agenda Item:** {metadata.agenda_item_text}\n")
    if metadata.status:
        parts.append(f"**Status:** {metadata.status}\n")
    if metadata.url:
        parts.append(f"\n**URL:** {metadata.url}\n")

    return "".join(parts)


def convert_tdoc(
    tdoc_id: str,
    output_path: Path | None = None,
) -> str:
    """Convert a TDoc to markdown format.

    Fetches TDoc metadata from WhatTheSpec and converts it to a markdown
    representation containing title, meeting info, source, and description.

    Args:
        tdoc_id: TDoc identifier (e.g., "S4-260001").
        output_path: Optional path to write markdown file. If None, returns
            the markdown string.

    Returns:
        Markdown formatted string if output_path is None, otherwise the
        output path as a string.

    Raises:
        ValueError: If TDoc cannot be found via WhatTheSpec.
    """
    normalized_id = tdoc_id.strip().upper()

    logger.info(f"Fetching TDoc metadata for {normalized_id} via WhatTheSpec")
    metadata = resolve_via_whatthespec(normalized_id)
    if metadata is None:
        msg = f"TDoc {normalized_id} not found via WhatTheSpec"
        raise ValueError(msg)

    markdown_content = _format_markdown(metadata)

    if output_path is None:
        return markdown_content

    # Persist to disk and hand back the path so callers know where it went.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(markdown_content, encoding="utf-8")
    logger.info(f"Wrote markdown to {output_path}")
    return str(output_path)


__all__ = [
    "convert_tdoc",
]
+2 −3
Original line number Diff line number Diff line
@@ -357,15 +357,15 @@ def generate_embeddings(

def query_embeddings(
    query: str,
    workspace: str,
    top_k: int = 5,
    workspace: str | None = None,
) -> list[tuple[DocumentChunk, float]]:
    """Query embeddings using semantic search.

    Args:
        query: Search query.
        workspace: Workspace scope (required).
        top_k: Number of results to return.
        workspace: Optional workspace scope (defaults to "default").

    Returns:
        List of (DocumentChunk, score) tuples.
@@ -379,7 +379,6 @@ def query_embeddings(

    # Encode query
    query_embedding = cast(Any, model.encode([query])[0])
    query_vector = query_embedding.tolist() if hasattr(query_embedding, "tolist") else list(query_embedding)
    query_vector = [float(value) for value in query_embedding.tolist()] if hasattr(query_embedding, "tolist") else [float(value) for value in query_embedding]

    # Search in storage
+22 −5
Original line number Diff line number Diff line
@@ -8,6 +8,7 @@ from datetime import datetime

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import GraphEdge, GraphEdgeType, GraphNode, GraphNodeType, GraphQueryResult
from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
from tdoc_crawler.ai.storage import AiStorage

logger = logging.getLogger(__name__)
@@ -157,13 +158,27 @@ def build_graph(

def query_graph(
    query: str,
    workspace: str,
    node_types: list[GraphNodeType] | None = None,
    top_k: int = 10,
    meeting_ids: list[str] | None = None,
    date_range: tuple[datetime, datetime] | None = None,
    storage: AiStorage | None = None,
) -> list[GraphQueryResult]:
    """Query the knowledge graph with temporal filtering and chronological sorting."""
) -> dict[str, list[GraphQueryResult]]:
    """Query the knowledge graph with temporal filtering and chronological sorting.

    Args:
        query: Search query string.
        workspace: Workspace scope (required).
        node_types: Filter by node types.
        top_k: Number of results to return.
        meeting_ids: Filter by meeting IDs.
        date_range: Filter by date range.
        storage: Optional storage instance.

    Returns:
        Dict with 'results' key containing list of GraphQueryResult objects.
    """
    if storage is None:
        config = AiConfig.from_env(cache_manager_name="default")
        store_path = config.ai_store_path
@@ -172,8 +187,10 @@ def query_graph(
            raise ValueError(msg)
        storage = AiStorage(store_path)

    nodes = storage.get_all_graph_nodes()
    edges = storage.get_all_graph_edges()
    normalized_workspace = normalize_workspace_name(workspace)

    nodes = storage.get_all_graph_nodes(workspace=normalized_workspace)
    edges = storage.get_all_graph_edges(workspace=normalized_workspace)

    matching_nodes = [n for n in nodes if not node_types or n.node_type in node_types]

@@ -198,7 +215,7 @@ def query_graph(
        results.append(GraphQueryResult(node=node, connected_edges=node_edges))

    logger.info(f"Graph query '{query}' returned {len(results)} results")
    return results
    return {"results": results}


def get_tdoc_evolution(tdoc_id: str, storage: AiStorage | None = None) -> list[GraphNode]:
Loading