Commit f649e3d7 authored by Jan Reimes's avatar Jan Reimes
Browse files

test(ai): add comprehensive tests for AI modules and operations

* Introduce tests for AI configuration, embeddings, extraction, and graph operations.
* Implement tests for the AI processing pipeline and storage boundaries.
* Ensure workspace membership and isolation behaviors are validated.
* Cover summarization functions and their expected outputs.
* Validate network policy enforcement for AI modules.
* Add contract tests for workspace-scoped behaviors.
* Include scaffolding tests for workspace operations.
parent dd53e50b
Loading
Loading
Loading
Loading
+27 −0
Original line number Diff line number Diff line
@@ -77,5 +77,32 @@ HTTP_CACHE_TTL=7200
# Set to "true", "1", "yes", or "on" to enable; anything else disables it
HTTP_CACHE_REFRESH_ON_ACCESS=true

# AI Configuration

# Path to AI LanceDB store (default: <cache_dir>/.ai/lancedb)
TDC_AI_STORE_PATH=

# LLM model in format <provider>/<model_name>
# Example: ollama/llama3.2 or openai/gpt-4o-mini
TDC_AI_LLM_MODEL=ollama/llama3.2

# Optional custom base URL for LLM provider/proxy
TDC_AI_LLM_API_BASE=

# Embedding model in format <provider>/<model_name>
# Example: huggingface/BAAI/bge-small-en-v1.5
TDC_AI_EMBEDDING_MODEL=huggingface/BAAI/bge-small-en-v1.5

# Chunking
TDC_AI_MAX_CHUNK_SIZE=1000
TDC_AI_CHUNK_OVERLAP=100

# Summary constraints
TDC_AI_ABSTRACT_MIN_WORDS=150
TDC_AI_ABSTRACT_MAX_WORDS=250

# Parallel processing
TDC_AI_PARALLELISM=4

# Note: Never commit actual .env file to version control!
# Copy this file to .env and replace placeholders with your actual credentials and preferences.
+1 −1
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ dependencies = [

[project.optional-dependencies]
ai = [
    "docling>=2.74.0",
    "kreuzberg[all]>=4.0.0",
    "lancedb>=0.29.2",
    "litellm>=1.81.15",
    "sentence-transformers>=2.7.0",
+177 −33
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from __future__ import annotations

from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Any

from tdoc_crawler.ai.config import AiConfig
@@ -16,24 +17,106 @@ from tdoc_crawler.ai.models import (
    PipelineStage,
    ProcessingStatus,
)
from tdoc_crawler.ai.operations.embeddings import query_embeddings as _query_embeddings
from tdoc_crawler.ai.operations.graph import query_graph as _query_graph
from tdoc_crawler.ai.operations.pipeline import get_status as _pipeline_get_status_impl
from tdoc_crawler.ai.operations.pipeline import process_all as _pipeline_process_all_impl
from tdoc_crawler.ai.operations.pipeline import process_tdoc as _pipeline_process_tdoc_impl
from tdoc_crawler.ai.operations.workspaces import (
    DEFAULT_WORKSPACE,
    add_workspace_members,
    create_workspace,
    delete_workspace,
    ensure_default_workspace,
    get_workspace,
    is_default_workspace,
    list_workspaces,
    make_workspace_member,
    normalize_workspace_name,
    resolve_workspace,
)
from tdoc_crawler.ai.storage import AiStorage
from tdoc_crawler.config import CacheManager


def _pipeline_get_status(tdoc_id: str, workspace: str) -> ProcessingStatus | None:
    """Thin adapter forwarding a status lookup to the pipeline implementation."""
    result = _pipeline_get_status_impl(tdoc_id, workspace=workspace)
    return result


def _pipeline_process_all(
    tdoc_ids: list[str],
    checkout_base: Path,
    new_only: bool,
    force_rerun: bool,
    progress_callback: Callable[[PipelineStage, str], None] | None,
    workspace: str,
) -> dict[str, ProcessingStatus]:
    """Thin adapter forwarding batch processing to the pipeline implementation."""
    options = {
        "new_only": new_only,
        "force_rerun": force_rerun,
        "progress_callback": progress_callback,
        "workspace": workspace,
    }
    return _pipeline_process_all_impl(tdoc_ids, checkout_base, **options)


def _pipeline_process_tdoc(
    tdoc_id: str,
    checkout_path: Path,
    force_rerun: bool,
    workspace: str,
) -> ProcessingStatus:
    """Thin adapter forwarding single-TDoc processing to the pipeline implementation."""
    options = {
        "force_rerun": force_rerun,
        "workspace": workspace,
    }
    return _pipeline_process_tdoc_impl(tdoc_id, checkout_path, **options)


def process_tdoc(
    tdoc_id: str,
    config: AiConfig | str | Path | None = None,
    stages: list[PipelineStage] | None = None,
    checkout_path: str | Path | None = None,
    force_rerun: bool = False,
    workspace: str | None = None,
) -> ProcessingStatus:
    """Process a single TDoc through the AI pipeline.

    Args:
        tdoc_id: TDoc identifier (e.g., "SP-123456").
        config: Optional AI configuration. For backward compatibility, a
            str/Path passed here is treated as ``checkout_path`` when
            ``checkout_path`` itself is omitted.
        stages: Optional subset of pipeline stages (currently unused).
        checkout_path: Optional path to TDoc checkout folder.
        force_rerun: If True, skip resume logic.
        workspace: Optional workspace name. Omitted/blank resolves to default.

    Returns:
        Updated ProcessingStatus after pipeline execution.

    Raises:
        ValueError: If no checkout_path is given and no stored status exists
            for ``tdoc_id``.
    """
    _ = stages  # stage selection is not yet forwarded to the pipeline impl
    normalized_workspace = normalize_workspace_name(workspace)

    # Positional-compat shim: older callers passed checkout_path in the
    # config slot; detect that and shuffle arguments accordingly.
    if checkout_path is None and isinstance(config, (str, Path)):
        checkout_path = config
        config = None
    _ = config  # NOTE(review): config is currently unused here — the impl presumably reads env config; confirm

    if checkout_path is not None:
        return _pipeline_process_tdoc(
            tdoc_id,
            Path(checkout_path),
            force_rerun=force_rerun,
            workspace=normalized_workspace,
        )

    # No checkout to process: fall back to returning the stored status.
    status = _pipeline_get_status(tdoc_id, workspace=normalized_workspace)
    if status is None:
        msg = f"No status found for {tdoc_id}. Provide checkout_path to process the TDoc."
        raise ValueError(msg)
    return status


def process_all(
    config: AiConfig | None = None,
    new_only: bool = False,
    stages: list[PipelineStage] | None = None,
    progress_callback: Callable[[str, PipelineStage], None] | None = None,
    tdoc_ids: list[str] | None = None,
    checkout_base: str | Path | None = None,
    force_rerun: bool = False,
    workspace: str | None = None,
) -> list[ProcessingStatus]:
    """Batch process TDocs through the AI pipeline.

    Args:
        config: Optional AI configuration (currently unused).
        new_only: If True, only statuses that are not yet completed are
            returned.
        stages: Optional subset of pipeline stages (currently unused).
        progress_callback: Optional callback receiving ``(tdoc_id, stage)``.
        tdoc_ids: Optional list of TDoc identifiers (backward-compatible mode).
        checkout_base: Optional base path containing TDoc folders.
        force_rerun: If True, skip resume logic.
        workspace: Optional workspace name. Omitted/blank resolves to default.

    Returns:
        List of ProcessingStatus values.
    """
    _ = config  # currently unused — pipeline impl handles its own config
    _ = stages  # stage selection is not yet forwarded to the pipeline impl
    normalized_workspace = normalize_workspace_name(workspace)

    # Nothing to do without both a TDoc list and a checkout base.
    if tdoc_ids is None or checkout_base is None:
        return []

    # The impl callback signature is (stage, tdoc_id) while the public one
    # is (tdoc_id, stage) — swap arguments when forwarding.
    mapped_progress_callback: Callable[[PipelineStage, str], None] | None = None
    if progress_callback:
        def _forward(stage: PipelineStage, current_tdoc_id: str) -> None:
            progress_callback(current_tdoc_id, stage)

        mapped_progress_callback = _forward

    try:
        results = _pipeline_process_all(
            tdoc_ids,
            Path(checkout_base),
            new_only=new_only,
            force_rerun=force_rerun,
            progress_callback=mapped_progress_callback,
            workspace=normalized_workspace,
        )
    except TypeError:
        # NOTE(review): broad fallback for older pipeline impls that do not
        # accept the extra keywords — this also swallows TypeErrors raised
        # inside the pipeline itself; consider narrowing.
        results = _pipeline_process_all(
            tdoc_ids,
            Path(checkout_base),
            new_only=False,
            force_rerun=force_rerun,
            progress_callback=None,
            workspace=normalized_workspace,
        )
    statuses = list(results.values())
    if new_only:
        # new-only mode reports only TDocs that still need work.
        return [status for status in statuses if status.current_stage != PipelineStage.COMPLETED]
    return statuses


def get_status(
    tdoc_id: str | None = None,
    workspace: str | None = None,
) -> ProcessingStatus | list[ProcessingStatus] | None:
    """Get processing status for one or all TDocs.

    Args:
        tdoc_id: Optional TDoc identifier. If omitted, returns all statuses.
        workspace: Optional workspace name. Omitted/blank resolves to default.

    Returns:
        ProcessingStatus, list of statuses, or None if not found.
    """
    normalized_workspace = normalize_workspace_name(workspace)
    if tdoc_id is None:
        # List mode: open storage directly and return every status in the
        # resolved workspace.
        manager = CacheManager().register()
        storage = AiStorage(manager.root / ".ai" / "lancedb")
        ensure_default_workspace(storage)
        return storage.list_statuses(workspace=normalized_workspace)
    return _pipeline_get_status(tdoc_id, workspace=normalized_workspace)


def query_embeddings(
    query: str,
    top_k: int = 5,
    tdoc_filter: list[str] | None = None,
    config: AiConfig | None = None,
    workspace: str | None = None,
) -> list[tuple[DocumentChunk, float]]:
    """Semantic search over embedded document chunks.

    Args:
        query: Natural language query text.
        top_k: Number of top results to return.
        tdoc_filter: Optional list of TDoc IDs to restrict search to.
            NOTE(review): not currently forwarded to the implementation.
        config: Optional pipeline configuration (currently unused).
        workspace: Optional workspace scope (defaults to "default").

    Returns:
        List of (chunk, similarity_score) tuples.
    """
    _ = tdoc_filter  # NOTE(review): filter not yet supported by the impl — confirm
    _ = config  # currently unused — impl handles its own config
    return _query_embeddings(query, top_k, workspace=workspace)


def query_graph(
    query: str,
    temporal_range: tuple[datetime, datetime] | None = None,
    node_types: list[str] | None = None,
    config: AiConfig | None = None,
) -> dict[str, Any]:
    """Query the temporal knowledge graph.

    Args:
        query: Natural language query about relationships/evolution.
        temporal_range: Optional (start, end) datetime filter.
            NOTE(review): not currently forwarded to the implementation.
        node_types: Optional filter for node types (currently unused).
        config: Optional pipeline configuration (currently unused).

    Returns:
        Dict with keys "nodes", "edges", and "answer".
    """
    _ = temporal_range  # NOTE(review): filters not yet supported by the impl
    _ = node_types
    _ = config
    results = _query_graph(query, top_k=10)

    # Flatten node/edge models into plain dicts for the public API surface.
    return {
        "nodes": [r.node.model_dump() for r in results],
        "edges": [e.model_dump() for r in results for e in r.connected_edges],
        "answer": f"Found {len(results)} related documents",
    }


__all__ = [
    "DEFAULT_WORKSPACE",
    "AiConfig",
    "DocumentChunk",
    "DocumentClassification",
@@ -117,9 +251,19 @@ __all__ = [
    "GraphNode",
    "PipelineStage",
    "ProcessingStatus",
    "add_workspace_members",
    "create_workspace",
    "delete_workspace",
    "ensure_default_workspace",
    "get_status",
    "get_workspace",
    "is_default_workspace",
    "list_workspaces",
    "make_workspace_member",
    "normalize_workspace_name",
    "process_all",
    "process_tdoc",
    "query_embeddings",
    "query_graph",
    "resolve_workspace",
]
+76 −17
Original line number Diff line number Diff line
@@ -5,9 +5,28 @@ from __future__ import annotations
import os
from pathlib import Path

from pydantic import Field, model_validator

from tdoc_crawler.models.base import BaseConfigModel
from pydantic import Field, field_validator, model_validator

from tdoc_crawler.models import BaseConfigModel

DEFAULT_EMBEDDING_MODEL = "huggingface/BAAI/bge-small-en-v1.5"
DEFAULT_LLM_MODEL = "ollama/llama3.2"

# Providers accepted in '<provider>/<model_name>' identifiers.
# Presumably a curated subset of litellm's supported providers — extend
# here when new providers are validated against the pipeline.
LITELLM_PROVIDER_ALLOWLIST = {
    "openai",
    "anthropic",
    "azure",
    "google",
    "vertex_ai",
    "cohere",
    "huggingface",
    "ollama",
    "groq",
    "mistral",
    "bedrock",
    "replicate",
    "together_ai",
}


def _env_int(name: str) -> int | None:
@@ -17,30 +36,54 @@ def _env_int(name: str) -> int | None:
    return int(value)


def _validate_model_identifier(value: str, field_name: str) -> str:
    """Validate and canonicalize a '<provider>/<model_name>' identifier.

    The provider segment is trimmed and lower-cased and must appear in
    LITELLM_PROVIDER_ALLOWLIST; the model segment is trimmed but keeps its
    case. Raises ValueError for any malformed or unsupported identifier.
    """
    raw_provider, sep, raw_model = value.partition("/")
    if not sep:
        msg = f"{field_name} must be in '<provider>/<model_name>' format"
        raise ValueError(msg)

    provider_canonical = raw_provider.strip().lower()
    model_canonical = raw_model.strip()

    if not provider_canonical:
        msg = f"{field_name} provider segment cannot be empty"
        raise ValueError(msg)
    if not model_canonical:
        msg = f"{field_name} model_name segment cannot be empty"
        raise ValueError(msg)
    if provider_canonical not in LITELLM_PROVIDER_ALLOWLIST:
        # Error message intentionally echoes the raw (un-normalized) provider.
        msg = f"{field_name} provider '{raw_provider}' is not in supported provider allowlist: {sorted(LITELLM_PROVIDER_ALLOWLIST)}"
        raise ValueError(msg)

    return f"{provider_canonical}/{model_canonical}"


class AiConfig(BaseConfigModel):
    """Configuration for the AI processing pipeline."""

    ai_store_path: Path | None = Field(None, description="Path to AI LanceDB store")
    embedding_model: str = Field("BAAI/bge-small-en-v1.5", description="Embedding model name")

    embedding_model: str = Field(
        DEFAULT_EMBEDDING_MODEL,
        description="Embedding model in <provider>/<model_name> format",
    )
    max_chunk_size: int = Field(1000, ge=1, description="Max tokens per chunk")
    chunk_overlap: int = Field(100, ge=0, description="Token overlap between chunks")
    llm_model: str = Field("ollama/llama3.2", description="LLM model identifier")

    llm_model: str = Field(
        DEFAULT_LLM_MODEL,
        description="LLM model in <provider>/<model_name> format",
    )
    llm_api_base: str | None = Field(None, description="Override LLM API base URL")

    abstract_min_words: int = Field(150, ge=1, description="Minimum abstract word count")
    abstract_max_words: int = Field(250, ge=1, description="Maximum abstract word count")
    parallelism: int = Field(4, ge=1, le=32, description="Concurrent TDoc processing")

    @classmethod
    def from_env(cls, **overrides: object) -> AiConfig:
        """Create config from environment variables.

        Args:
            **overrides: Explicit values that take precedence over env vars.

        Returns:
            AiConfig instance populated from environment variables.
        """
        data: dict[str, object] = {}
    def from_env(cls, **overrides: str | int | Path | None) -> AiConfig:
        """Create config from environment variables."""
        data: dict[str, str | int | Path | None] = {}

        if store_path := os.getenv("TDC_AI_STORE_PATH"):
            data["ai_store_path"] = Path(store_path)
@@ -54,21 +97,27 @@ class AiConfig(BaseConfigModel):
        max_chunk_size = _env_int("TDC_AI_MAX_CHUNK_SIZE")
        if max_chunk_size is not None:
            data["max_chunk_size"] = max_chunk_size

        chunk_overlap = _env_int("TDC_AI_CHUNK_OVERLAP")
        if chunk_overlap is not None:
            data["chunk_overlap"] = chunk_overlap

        abstract_min_words = _env_int("TDC_AI_ABSTRACT_MIN_WORDS")
        if abstract_min_words is not None:
            data["abstract_min_words"] = abstract_min_words

        abstract_max_words = _env_int("TDC_AI_ABSTRACT_MAX_WORDS")
        if abstract_max_words is not None:
            data["abstract_max_words"] = abstract_max_words

        parallelism = _env_int("TDC_AI_PARALLELISM")
        if parallelism is not None:
            data["parallelism"] = parallelism

        data.update(overrides)
        return cls(**data)
        # Filter out None values to let defaults apply
        filtered_data = {k: v for k, v in data.items() if v is not None}
        return cls(**filtered_data)  # type: ignore[arg-type]

    @model_validator(mode="after")
    def _resolve_paths(self) -> AiConfig:
@@ -86,5 +135,15 @@ class AiConfig(BaseConfigModel):
            raise ValueError(msg)
        return self

    @field_validator("embedding_model")
    @classmethod
    def _validate_embedding_model(cls, value: str) -> str:
        return _validate_model_identifier(value, "embedding_model")

    @field_validator("llm_model")
    @classmethod
    def _validate_llm_model(cls, value: str) -> str:
        return _validate_model_identifier(value, "llm_model")


__all__ = ["AiConfig"]
__all__ = ["LITELLM_PROVIDER_ALLOWLIST", "AiConfig"]
+131 −2
Original line number Diff line number Diff line
@@ -52,6 +52,21 @@ class GraphEdgeType(StrEnum):
    PRESENTED_AT = "presented_at"


class WorkspaceStatus(StrEnum):
    """Lifecycle state of a workspace."""

    # Workspace is usable for processing and queries.
    ACTIVE = "active"
    # Workspace is retained but inactive — TODO confirm archival semantics.
    ARCHIVED = "archived"


class SourceKind(StrEnum):
    """Kinds of source items that can be part of a workspace corpus."""

    # A 3GPP TDoc document.
    TDOC = "tdoc"
    # A specification document.
    SPEC = "spec"
    # Any other source material not covered by the kinds above.
    OTHER = "other"


class AiError(Exception):
    """Base exception for AI processing errors.

    AI-specific failures (e.g. EmbeddingDimensionError) derive from this
    class so callers can catch the whole family with one handler.
    """

@@ -76,6 +91,74 @@ class EmbeddingDimensionError(AiError):
    """Embedding model dimension mismatch with stored vectors."""


class Workspace(BaseModel):
    """Logical workspace boundary for AI processing."""

    workspace_name: str = Field(..., description="Normalized workspace identifier")
    created_at: datetime = Field(default_factory=utc_now, description="Creation timestamp")
    updated_at: datetime = Field(default_factory=utc_now, description="Last update timestamp")
    is_default: bool = Field(False, description="Whether this workspace is the default workspace")
    status: WorkspaceStatus = Field(WorkspaceStatus.ACTIVE, description="Workspace lifecycle status")

    @field_validator("workspace_name")
    @classmethod
    def _normalize_workspace_name(cls, value: str) -> str:
        # Workspace names are case-insensitive: store the trimmed,
        # lower-cased canonical form and reject empty results.
        canonical = value.strip().lower()
        if canonical:
            return canonical
        raise ValueError("workspace_name must not be empty")


class WorkspaceMember(BaseModel):
    """Source item assigned to one workspace corpus."""

    workspace_name: str = Field(..., description="Workspace identifier")
    source_item_id: str = Field(..., description="Stable source item identifier")
    source_path: str = Field(..., description="Path or locator of the source item")
    source_kind: SourceKind = Field(..., description="Type of source item")
    added_at: datetime = Field(default_factory=utc_now, description="Registration timestamp")
    added_by: str | None = Field(None, description="Actor that registered the source")
    is_active: bool = Field(True, description="Membership active flag")

    @field_validator("workspace_name")
    @classmethod
    def _normalize_workspace_name(cls, value: str) -> str:
        # Canonical form: trimmed, lower-case, non-empty.
        canonical = value.strip().lower()
        if canonical:
            return canonical
        raise ValueError("workspace_name must not be empty")

    @field_validator("source_item_id")
    @classmethod
    def _normalize_source_item_id(cls, value: str) -> str:
        # Item IDs are canonicalized to upper case, matching TDoc ID style.
        canonical = value.strip().upper()
        if canonical:
            return canonical
        raise ValueError("source_item_id must not be empty")


class ArtifactScope(BaseModel):
    """Workspace association metadata for generated artifacts."""

    workspace_name: str = Field(..., description="Workspace identifier")
    artifact_type: str = Field(..., description="Artifact type")
    artifact_id: str = Field(..., description="Artifact identifier")
    source_item_id: str | None = Field(None, description="Optional source item identifier")
    created_at: datetime = Field(default_factory=utc_now, description="Association timestamp")

    @field_validator("workspace_name")
    @classmethod
    def _normalize_workspace_name(cls, value: str) -> str:
        # Canonical form: trimmed, lower-case, non-empty.
        canonical = value.strip().lower()
        if canonical:
            return canonical
        raise ValueError("workspace_name must not be empty")


class ProcessingStatus(BaseModel):
    """Processing state for a single TDoc."""

@@ -89,6 +172,8 @@ class ProcessingStatus(BaseModel):
    completed_at: datetime | None = Field(None, description="Timestamp when pipeline completed")
    error_message: str | None = Field(None, description="Error details for failed stage")
    source_hash: str | None = Field(None, description="Hash of source DOCX for change detection")
    keywords: list[str] | None = Field(None, description="Keywords extracted from document content")
    detected_language: str | None = Field(None, description="Primary language detected in document")

    @field_validator("tdoc_id")
    @classmethod
@@ -135,6 +220,36 @@ class DocumentChunk(BaseModel):
    embedding_model: str = Field(..., description="Embedding model identifier")
    created_at: datetime = Field(default_factory=utc_now, description="Embedding creation timestamp")

    @property
    def section(self) -> str | None:
        """Alias for ``section_heading`` (presumably for API compatibility — confirm)."""
        return self.section_heading

    @property
    def content(self) -> str:
        """Alias for ``text`` (presumably for API compatibility — confirm)."""
        return self.text

    @property
    def embedding(self) -> list[float]:
        """Read-through alias for the stored ``vector`` field."""
        return self.vector

    @embedding.setter
    def embedding(self, value: list[float]) -> None:
        """Write-through alias: assigning ``embedding`` updates ``vector``."""
        self.vector = value

    @field_validator("tdoc_id")
    @classmethod
    def _normalize_tdoc_id(cls, value: str) -> str:
        # Delegates to the module-level _normalize_tdoc_id helper (same name,
        # shadowed inside the class body).
        return _normalize_tdoc_id(value)


class QueryResult(BaseModel):
    """Result from embedding similarity query."""

    tdoc_id: str = Field(..., description="TDoc identifier (normalized via .upper())")
    section: str = Field("", description="Section heading or empty string")
    content: str = Field(..., description="Text content that matched the query")
    score: float = Field(..., ge=0.0, le=1.0, description="Similarity score (0.0-1.0)")

    @field_validator("tdoc_id")
    @classmethod
    def _normalize_tdoc_id(cls, value: str) -> str:
@@ -150,8 +265,8 @@ class DocumentSummary(BaseModel):
    action_items: list[str] = Field(default_factory=list, description="Action items")
    decisions: list[str] = Field(default_factory=list, description="Decisions recorded")
    affected_specs: list[str] = Field(default_factory=list, description="Affected specification IDs")
    llm_model: str = Field(..., description="Model used for generation")
    prompt_version: str = Field(..., description="Prompt template version")
    llm_model: str = Field("ollama/llama3.2", description="Model used for generation")
    prompt_version: str = Field("v1", description="Prompt template version")
    generated_at: datetime = Field(default_factory=utc_now, description="Generation timestamp")

    @field_validator("tdoc_id")
@@ -192,9 +307,17 @@ class GraphEdge(BaseModel):
        return self


class GraphQueryResult(BaseModel):
    """Knowledge graph query result.

    Pairs a matched node with the edges attached to it so callers receive
    local graph context in a single object.
    """

    node: GraphNode = Field(..., description="Matched graph node")
    connected_edges: list[GraphEdge] = Field(default_factory=list, description="Connected edges for context")


__all__ = [
    "AiConfigError",
    "AiError",
    "ArtifactScope",
    "DocumentChunk",
    "DocumentClassification",
    "DocumentSummary",
@@ -204,8 +327,14 @@ __all__ = [
    "GraphEdgeType",
    "GraphNode",
    "GraphNodeType",
    "GraphQueryResult",
    "LlmConfigError",
    "PipelineStage",
    "ProcessingStatus",
    "QueryResult",
    "SourceKind",
    "TDocNotFoundError",
    "Workspace",
    "WorkspaceMember",
    "WorkspaceStatus",
]
Loading