Commit 357f3877 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(ai): implement AI document processing pipeline and storage layer

- Implement Phase 2.
- Add configuration models for AI processing.
- Create data models for processing status, document classification, and chunks.
- Implement storage layer using LanceDB for AI artifacts.
- Define operations for processing TDocs through the AI pipeline.
- Introduce optional dependencies for AI processing.
parent c2b3731a
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -36,6 +36,14 @@ dependencies = [
    "zipinspect>=0.1.2",
]

[project.optional-dependencies]
ai = [
    "docling>=2.74.0",
    "lancedb>=0.29.2",
    "litellm>=1.81.15",
    "sentence-transformers>=2.7.0",
]

[project.urls]
Repository = "https://forge.3gpp.org/rep/reimes/tdoc-crawler"

+125 −0
Original line number Diff line number Diff line
"""AI document processing domain package."""

from __future__ import annotations

from collections.abc import Callable
from datetime import datetime
from typing import Any

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import (
    DocumentChunk,
    DocumentClassification,
    DocumentSummary,
    GraphEdge,
    GraphNode,
    PipelineStage,
    ProcessingStatus,
)


def process_tdoc(
    tdoc_id: str,
    config: AiConfig | None = None,
    stages: list[PipelineStage] | None = None,
) -> ProcessingStatus:
    """Run the AI pipeline for a single TDoc.

    Args:
        tdoc_id: TDoc identifier (e.g., "SP-123456").
        config: Pipeline configuration. Defaults to AiConfig.from_env().
        stages: Specific stages to run. If None, runs all applicable stages.

    Returns:
        Updated ProcessingStatus after pipeline execution.

    Raises:
        NotImplementedError: Always; the pipeline is not implemented yet.
    """
    msg = "AI pipeline is not implemented yet"
    raise NotImplementedError(msg)


def process_all(
    config: AiConfig | None = None,
    new_only: bool = False,
    stages: list[PipelineStage] | None = None,
    progress_callback: Callable[[str, PipelineStage], None] | None = None,
) -> list[ProcessingStatus]:
    """Batch process all (or new-only) TDocs through the AI pipeline.

    Args:
        config: Pipeline configuration. Defaults to AiConfig.from_env().
        new_only: If True, only process TDocs not yet in processing_status.
        stages: Specific stages to run. If None, runs all applicable stages.
        progress_callback: Optional callback for reporting progress.

    Returns:
        List of ProcessingStatus entries for processed TDocs.

    Raises:
        NotImplementedError: Always; the pipeline is not implemented yet.
    """
    msg = "AI pipeline is not implemented yet"
    raise NotImplementedError(msg)


def get_status(tdoc_id: str | None = None) -> ProcessingStatus | list[ProcessingStatus]:
    """Look up processing status for one or all TDocs.

    Args:
        tdoc_id: If provided, return status for this TDoc. If None, return all.

    Returns:
        Single ProcessingStatus or list of ProcessingStatus entries.

    Raises:
        NotImplementedError: Always; status lookup is not implemented yet.
    """
    msg = "AI status lookup is not implemented yet"
    raise NotImplementedError(msg)


def query_embeddings(
    query: str,
    top_k: int = 5,
    tdoc_filter: list[str] | None = None,
    config: AiConfig | None = None,
) -> list[tuple[DocumentChunk, float]]:
    """Semantic search over embedded document chunks.

    Args:
        query: Natural language query text.
        top_k: Number of top results to return.
        tdoc_filter: Optional list of TDoc IDs to restrict search to.
        config: Pipeline configuration.

    Returns:
        List of (chunk, similarity_score) tuples sorted by descending score.

    Raises:
        NotImplementedError: Always; embedding search is not implemented yet.
    """
    msg = "Embedding query is not implemented yet"
    raise NotImplementedError(msg)


def query_graph(
    query: str,
    temporal_range: tuple[datetime, datetime] | None = None,
    node_types: list[str] | None = None,
    config: AiConfig | None = None,
) -> dict[str, Any]:
    """Query the temporal knowledge graph.

    Args:
        query: Natural language query about relationships/evolution.
        temporal_range: Optional (start, end) datetime filter.
        node_types: Optional filter for node types.
        config: Pipeline configuration.

    Returns:
        Dict with keys "nodes", "edges", and "answer".

    Raises:
        NotImplementedError: Always; graph querying is not implemented yet.
    """
    msg = "Graph query is not implemented yet"
    raise NotImplementedError(msg)


# Public package API: re-exported config/model classes plus the operation
# entry points defined above (kept sorted alphabetically).
__all__ = [
    "AiConfig",
    "DocumentChunk",
    "DocumentClassification",
    "DocumentSummary",
    "GraphEdge",
    "GraphNode",
    "PipelineStage",
    "ProcessingStatus",
    "get_status",
    "process_all",
    "process_tdoc",
    "query_embeddings",
    "query_graph",
]
+90 −0
Original line number Diff line number Diff line
"""Configuration for the AI document processing pipeline."""

from __future__ import annotations

import os
from pathlib import Path

from pydantic import Field, model_validator

from tdoc_crawler.models.base import BaseConfigModel


def _env_int(name: str) -> int | None:
    value = os.getenv(name)
    if value is None or value == "":
        return None
    return int(value)


class AiConfig(BaseConfigModel):
    """Configuration for the AI processing pipeline."""

    ai_store_path: Path | None = Field(None, description="Path to AI LanceDB store")
    embedding_model: str = Field("BAAI/bge-small-en-v1.5", description="Embedding model name")
    max_chunk_size: int = Field(1000, ge=1, description="Max tokens per chunk")
    chunk_overlap: int = Field(100, ge=0, description="Token overlap between chunks")
    llm_model: str = Field("ollama/llama3.2", description="LLM model identifier")
    llm_api_base: str | None = Field(None, description="Override LLM API base URL")
    abstract_min_words: int = Field(150, ge=1, description="Minimum abstract word count")
    abstract_max_words: int = Field(250, ge=1, description="Maximum abstract word count")
    parallelism: int = Field(4, ge=1, le=32, description="Concurrent TDoc processing")

    @classmethod
    def from_env(cls, **overrides: object) -> AiConfig:
        """Build a config from ``TDC_AI_*`` environment variables.

        Args:
            **overrides: Explicit values that take precedence over env vars.

        Returns:
            AiConfig instance populated from environment variables.
        """
        data: dict[str, object] = {}

        # Path-valued variable: converted to Path, skipped when unset/empty.
        raw_store_path = os.getenv("TDC_AI_STORE_PATH")
        if raw_store_path:
            data["ai_store_path"] = Path(raw_store_path)

        # String-valued variables are applied only when non-empty.
        for field_name, env_name in (
            ("embedding_model", "TDC_AI_EMBEDDING_MODEL"),
            ("llm_model", "TDC_AI_LLM_MODEL"),
            ("llm_api_base", "TDC_AI_LLM_API_BASE"),
        ):
            raw_value = os.getenv(env_name)
            if raw_value:
                data[field_name] = raw_value

        # Integer-valued variables are applied whenever set (even to 0),
        # following _env_int's unset/empty -> None convention.
        for field_name, env_name in (
            ("max_chunk_size", "TDC_AI_MAX_CHUNK_SIZE"),
            ("chunk_overlap", "TDC_AI_CHUNK_OVERLAP"),
            ("abstract_min_words", "TDC_AI_ABSTRACT_MIN_WORDS"),
            ("abstract_max_words", "TDC_AI_ABSTRACT_MAX_WORDS"),
            ("parallelism", "TDC_AI_PARALLELISM"),
        ):
            parsed = _env_int(env_name)
            if parsed is not None:
                data[field_name] = parsed

        # Explicit keyword arguments win over anything read from the env.
        data.update(overrides)
        return cls(**data)

    @model_validator(mode="after")
    def _resolve_paths(self) -> AiConfig:
        # Default the store location under the cache directory when unset.
        # NOTE(review): cache_dir is inherited from BaseConfigModel — confirm
        # it is always populated before this validator runs.
        if self.ai_store_path is None:
            self.ai_store_path = self.cache_dir / ".ai" / "lancedb"
        return self

    @model_validator(mode="after")
    def _validate_bounds(self) -> AiConfig:
        # Cross-field constraints that per-field ge/le bounds cannot express.
        if self.abstract_min_words > self.abstract_max_words:
            msg = "abstract_max_words must be >= abstract_min_words"
            raise ValueError(msg)
        if self.max_chunk_size <= self.chunk_overlap:
            msg = "chunk_overlap must be less than max_chunk_size"
            raise ValueError(msg)
        return self


__all__ = ["AiConfig"]  # sole public name of this module
+211 −0
Original line number Diff line number Diff line
"""Pydantic models for the AI document processing pipeline."""

from __future__ import annotations

from datetime import datetime
from enum import StrEnum
from typing import Any

from pydantic import BaseModel, Field, field_validator, model_validator

from tdoc_crawler.utils.misc import utc_now


def _normalize_tdoc_id(value: str) -> str:
    """Canonicalize a TDoc identifier by trimming whitespace and uppercasing."""
    trimmed = value.strip()
    return trimmed.upper()


class PipelineStage(StrEnum):
    """Stages of the AI processing pipeline.

    As a StrEnum, each member compares equal to its string value, so stages
    can be stored and matched as plain strings.
    """

    PENDING = "pending"  # not yet picked up by the pipeline
    CLASSIFYING = "classifying"  # finding the main document in the TDoc folder
    EXTRACTING = "extracting"  # extracting text from the main document
    EMBEDDING = "embedding"  # chunking + embedding extracted text
    SUMMARIZING = "summarizing"  # generating the DocumentSummary
    GRAPHING = "graphing"  # building knowledge-graph nodes/edges
    COMPLETED = "completed"  # all applicable stages finished
    FAILED = "failed"  # a stage errored; see ProcessingStatus.error_message


class GraphNodeType(StrEnum):
    """Types of nodes in the knowledge graph."""

    TDOC = "tdoc"
    MEETING = "meeting"
    SPEC = "spec"
    WORK_ITEM = "work_item"
    # NOTE: the serialized value is "cr", not "change_request".
    CHANGE_REQUEST = "cr"
    COMPANY = "company"
    CONCEPT = "concept"


class GraphEdgeType(StrEnum):
    """Types of edges in the knowledge graph.

    Edges are directed source -> target (see GraphEdge.source_id/target_id).
    """

    DISCUSSES = "discusses"
    REVISES = "revises"
    REFERENCES = "references"
    SUPERSEDES = "supersedes"
    AUTHORED_BY = "authored_by"
    MERGED_INTO = "merged_into"
    PRESENTED_AT = "presented_at"


# Exception hierarchy: catch AiError to handle any AI-pipeline failure;
# the subclasses below distinguish the common failure causes.
class AiError(Exception):
    """Base exception for AI processing errors."""


class TDocNotFoundError(AiError):
    """TDoc not found in database or has no files."""


class ExtractionError(AiError):
    """DOCX extraction failed (corrupt, password-protected, etc.)."""


class LlmConfigError(AiError):
    """LLM endpoint not configured or unreachable."""


class AiConfigError(AiError):
    """Invalid or missing AI configuration."""


class EmbeddingDimensionError(AiError):
    """Embedding model dimension mismatch with stored vectors."""


class ProcessingStatus(BaseModel):
    """Processing state for a single TDoc.

    One ``*_at`` timestamp per pipeline stage records when that stage
    finished; ``error_message`` holds details when a stage fails, and
    ``source_hash`` supports change detection against the source DOCX.
    """

    tdoc_id: str = Field(..., description="TDoc identifier (normalized via .upper())")
    current_stage: PipelineStage = Field(PipelineStage.PENDING, description="Current pipeline stage")
    classified_at: datetime | None = Field(None, description="Timestamp when classification completed")
    extracted_at: datetime | None = Field(None, description="Timestamp when extraction completed")
    embedded_at: datetime | None = Field(None, description="Timestamp when embedding completed")
    summarized_at: datetime | None = Field(None, description="Timestamp when summarization completed")
    graphed_at: datetime | None = Field(None, description="Timestamp when graphing completed")
    completed_at: datetime | None = Field(None, description="Timestamp when pipeline completed")
    error_message: str | None = Field(None, description="Error details for failed stage")
    source_hash: str | None = Field(None, description="Hash of source DOCX for change detection")

    @field_validator("tdoc_id")
    @classmethod
    def _normalize_tdoc_id(cls, value: str) -> str:
        # Strip + uppercase so IDs compare equal across all AI tables.
        return _normalize_tdoc_id(value)


class DocumentClassification(BaseModel):
    """Classification of a file within a TDoc folder.

    Records whether a file was judged to be the TDoc's main document,
    plus the confidence score and the heuristic that decided it.
    """

    tdoc_id: str = Field(..., description="TDoc identifier (normalized via .upper())")
    file_path: str = Field(..., description="Relative path within checkout folder")
    is_main_document: bool = Field(..., description="Whether this file is the main document")
    confidence: float = Field(..., description="Confidence score between 0.0 and 1.0")
    decisive_heuristic: str = Field(..., description="Rule that determined the classification")
    file_size_bytes: int = Field(..., ge=0, description="File size in bytes")
    classified_at: datetime = Field(default_factory=utc_now, description="Classification timestamp")

    @field_validator("tdoc_id")
    @classmethod
    def _normalize_tdoc_id(cls, value: str) -> str:
        # Strip + uppercase so IDs compare equal across all AI tables.
        return _normalize_tdoc_id(value)

    @field_validator("confidence")
    @classmethod
    def _validate_confidence(cls, value: float) -> float:
        # Explicit validator (rather than Field ge/le) keeps a custom message.
        if not 0.0 <= value <= 1.0:
            msg = "confidence must be between 0.0 and 1.0"
            raise ValueError(msg)
        return value


class DocumentChunk(BaseModel):
    """A chunk of extracted document text with its embedding.

    ``chunk_id`` follows the "{tdoc_id}:{chunk_index}" convention; character
    offsets refer to positions in the extracted Markdown text.
    """

    chunk_id: str = Field(..., description="Unique chunk identifier '{tdoc_id}:{chunk_index}'")
    tdoc_id: str = Field(..., description="TDoc identifier (normalized via .upper())")
    section_heading: str | None = Field(None, description="Heading for the chunk's section")
    chunk_index: int = Field(..., ge=0, description="Position within the document")
    text: str = Field(..., description="Chunk text content")
    char_offset_start: int = Field(..., ge=0, description="Start offset in Markdown")
    char_offset_end: int = Field(..., ge=0, description="End offset in Markdown")
    vector: list[float] = Field(..., description="Embedding vector")
    embedding_model: str = Field(..., description="Embedding model identifier")
    created_at: datetime = Field(default_factory=utc_now, description="Embedding creation timestamp")

    @field_validator("tdoc_id")
    @classmethod
    def _normalize_tdoc_id(cls, value: str) -> str:
        # Strip + uppercase so IDs compare equal across all AI tables.
        return _normalize_tdoc_id(value)


class DocumentSummary(BaseModel):
    """AI-generated summary for a TDoc.

    ``llm_model`` and ``prompt_version`` record the provenance of the
    generated text so summaries can be regenerated or compared later.
    """

    tdoc_id: str = Field(..., description="TDoc identifier (normalized via .upper())")
    abstract: str = Field(..., description="150-250 word abstract")
    key_points: list[str] = Field(default_factory=list, description="Key findings")
    action_items: list[str] = Field(default_factory=list, description="Action items")
    decisions: list[str] = Field(default_factory=list, description="Decisions recorded")
    affected_specs: list[str] = Field(default_factory=list, description="Affected specification IDs")
    llm_model: str = Field(..., description="Model used for generation")
    prompt_version: str = Field(..., description="Prompt template version")
    generated_at: datetime = Field(default_factory=utc_now, description="Generation timestamp")

    @field_validator("tdoc_id")
    @classmethod
    def _normalize_tdoc_id(cls, value: str) -> str:
        # Strip + uppercase so IDs compare equal across all AI tables.
        return _normalize_tdoc_id(value)


class GraphNode(BaseModel):
    """A node in the temporal knowledge graph.

    ``valid_from``/``valid_to`` bound the node's temporal validity; either
    may be None for an open-ended interval.
    """

    node_id: str = Field(..., description="Unique node identifier")
    node_type: GraphNodeType = Field(..., description="Node type")
    label: str = Field(..., description="Human-readable label")
    valid_from: datetime | None = Field(None, description="Temporal validity start")
    valid_to: datetime | None = Field(None, description="Temporal validity end")
    properties: dict[str, Any] = Field(default_factory=dict, description="Type-specific properties")
    created_at: datetime = Field(default_factory=utc_now, description="Node creation timestamp")


class GraphEdge(BaseModel):
    """An edge in the temporal knowledge graph.

    Directed source -> target; ``provenance`` records how the edge was
    derived so low-confidence edges can be audited.
    """

    edge_id: str = Field(..., description="Edge identifier '{source}->{edge_type}->{target}'")
    source_id: str = Field(..., description="Source node id")
    target_id: str = Field(..., description="Target node id")
    edge_type: GraphEdgeType = Field(..., description="Edge type")
    weight: float = Field(1.0, description="Relationship strength")
    temporal_context: str | None = Field(None, description="Meeting or date context")
    provenance: str = Field(..., description="How this edge was derived")
    created_at: datetime = Field(default_factory=utc_now, description="Edge creation timestamp")

    @model_validator(mode="after")
    def _validate_weight(self) -> GraphEdge:
        # Zero/negative weights are meaningless as a relationship strength.
        if self.weight <= 0:
            msg = "weight must be positive"
            raise ValueError(msg)
        return self


# Public API of this module: models, enums, and the exception hierarchy
# (kept sorted alphabetically).
__all__ = [
    "AiConfigError",
    "AiError",
    "DocumentChunk",
    "DocumentClassification",
    "DocumentSummary",
    "EmbeddingDimensionError",
    "ExtractionError",
    "GraphEdge",
    "GraphEdgeType",
    "GraphNode",
    "GraphNodeType",
    "LlmConfigError",
    "PipelineStage",
    "ProcessingStatus",
    "TDocNotFoundError",
]
+1 −0
Original line number Diff line number Diff line
"""AI processing operations."""
Loading