Commit 41e07f86 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(ai): enhance configuration and models for graph query processing

* Add GraphQueryLevel type for defining query sophistication levels.
* Introduce ProcessingFailureType for classifying processing failures.
* Update AiConfig to include graph query level configuration.
* Modify models to support new failure types and properties.
* Adjust storage handling to accommodate new failure type field.
parent 674f8208
Loading
Loading
Loading
Loading
+23 −14
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ TDC_EOL_PROMPT=false
# Cache and Directory Configuration

# Cache directory for storing downloaded metadata and files (default: ~/.tdoc-crawler)
TDC_CACHE_DIR=/path/to/cache/dir
# TDC_CACHE_DIR=/path/to/cache/dir

# Checkout directory for downloaded TDocs is managed under the cache directory
# by default: <cache_dir>/checkout (use `--cache-dir` or `TDC_CACHE_DIR` to change)
@@ -34,32 +34,32 @@ TDC_TIMEOUT=60
TDC_MAX_RETRIES=3

# Maximum total crawl duration in seconds (default: None = unlimited)
TDC_OVERALL_TIMEOUT=
# TDC_OVERALL_TIMEOUT=

# Filtering and Limits

# Filter by working group (comma-separated list)
TDC_WORKING_GROUP=SA2,RAN1
# TDC_WORKING_GROUP=SA2,RAN1

# Filter by sub-working group (comma-separated list)
TDC_SUB_GROUP=RAN1,RAN2
# TDC_SUB_GROUP=RAN1,RAN2

# Limit number of TDocs to crawl (default: None = no limit)
TDC_LIMIT_TDOCS=100
# TDC_LIMIT_TDOCS=100

# Limit total meetings to crawl (default: None = no limit)
TDC_LIMIT_MEETINGS=10
# TDC_LIMIT_MEETINGS=50

# Query date range - start date (YYYY, YYYY-MM, or YYYY-MM-DD format)
TDC_START_DATE=2024-01-01
# TDC_START_DATE=2024-01-01

# Query date range - end date (YYYY, YYYY-MM, or YYYY-MM-DD format)
TDC_END_DATE=2024-12-31
# TDC_END_DATE=2024-12-31

# Output Configuration

# Output format for query results (e.g., table, csv, json)
TDC_OUTPUT=table
# TDC_OUTPUT=table

# Logging

@@ -71,17 +71,17 @@ TDC_VERBOSE=false
# Controls caching behavior for all HTTP requests

# Time-to-live for cached HTTP responses in seconds (default: 7200 = 2 hours)
HTTP_CACHE_TTL=7200
# HTTP_CACHE_TTL=7200

# Whether to refresh TTL when a cached response is accessed (default: true)
# Set to "true", "1", "yes", or "on" to enable; anything else disables it
HTTP_CACHE_REFRESH_ON_ACCESS=true
# HTTP_CACHE_REFRESH_ON_ACCESS=true

# AI Configuration
# Note: AI module requires API keys for cloud providers. See docs/ai.md for details.

# Path to AI LanceDB store (default: <cache_dir>/.ai/lancedb)
TDC_AI_STORE_PATH=
# TDC_AI_STORE_PATH=

# LLM model in format <provider>/<model_name>
# Recommended: openrouter/openrouter/free (free tier, no subscription required)
@@ -89,12 +89,15 @@ TDC_AI_STORE_PATH=
TDC_AI_LLM_MODEL=openrouter/openrouter/free

# Optional custom base URL for LLM provider/proxy
TDC_AI_LLM_API_BASE=
# TDC_AI_LLM_API_BASE=

# Optional API key for LLM provider, will override default environment variable (e.g., OPENROUTER_API_KEY for OpenRouter)
# TDC_AI_LLM_API_KEY=

# Embedding model (HuggingFace sentence-transformers model ID)
# Default: sentence-transformers/all-MiniLM-L6-v2 (384 dimensions, popular and fast)
# See https://huggingface.co/models?library=sentence-transformers for alternatives
TDC_AI_EMBEDDING_MODEL=perplexity-ai/pplx-embed-context-v1-0.6b
TDC_AI_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2

# Embedding backend for sentence-transformers (default: torch)
# Options: torch (default), onnx (faster inference), openvino (Intel hardware optimization)
@@ -116,6 +119,12 @@ TDC_AI_ABSTRACT_MAX_WORDS=250
# Parallel processing
TDC_AI_PARALLELISM=4

# Graph query level (simple|medium|advanced) - default: simple
# simple: Return count and list without synthesis
# medium: Parse query keywords, filter nodes, generate simple text summary
# advanced: Use LLM to synthesize answer from graph + embeddings (GraphRAG)
TDC_GRAPH_QUERY_LEVEL=simple

# Note: Never commit actual .env file to version control!
# Copy this file to .env and replace placeholders with your actual credentials and preferences.

+4 −1
Original line number Diff line number Diff line
@@ -19,8 +19,11 @@ dependencies = [
    "kreuzberg[all]>=4.0.0",
    "lancedb>=0.29.2",
    "litellm>=1.81.15",
    "sentence-transformers[openvino,onnx-gpu]>=2.7.0",
    "sentence-transformers[openvino,onnx,onnx-gpu]>=2.7.0",
    "tokenizers>=0.22.2",
    "optimum[openvino]"

    #"nvidia-cudnn-cu12", "nvidia-cuda-runtime-cu12", "nvidia-cufft-cu12"
]

[project.urls]
+25 −2
Original line number Diff line number Diff line
@@ -3,8 +3,8 @@
from __future__ import annotations

import os
from typing import Literal
from pathlib import Path
from typing import Literal

import litellm
from pydantic import Field, field_validator, model_validator
@@ -17,6 +17,10 @@ DEFAULT_LLM_MODEL = "openrouter/openrouter/free"

type Backend = Literal["torch", "onnx", "openvino"]

# Graph query level type
type GraphQueryLevel = Literal["simple", "medium", "advanced"]


def _env_int(name: str) -> int | None:
    value = os.getenv(name)
    if value is None or value == "":
@@ -104,6 +108,11 @@ class AiConfig(BaseConfigModel):
    abstract_max_words: int = Field(250, ge=1, description="Maximum abstract word count")
    parallelism: int = Field(4, ge=1, le=32, description="Concurrent TDoc processing")

    graph_query_level: GraphQueryLevel = Field(
        "simple",
        description="Level of graph query answer generation (simple|medium|advanced)",
    )

    @classmethod
    def from_env(cls, **overrides: str | int | Path | None) -> AiConfig:
        """Create config from environment variables."""
@@ -145,6 +154,9 @@ class AiConfig(BaseConfigModel):
        if parallelism is not None:
            data["parallelism"] = parallelism

        if graph_query_level := os.getenv("TDC_GRAPH_QUERY_LEVEL"):
            data["graph_query_level"] = graph_query_level

        data.update(overrides)
        # Filter out None values to let defaults apply
        filtered_data = {k: v for k, v in data.items() if v is not None}
@@ -190,5 +202,16 @@ class AiConfig(BaseConfigModel):
    def _validate_llm_model(cls, value: str) -> str:
        """Validate ``llm_model`` via the shared identifier check.

        Delegates to ``_validate_model_identifier`` — presumably enforcing the
        ``<provider>/<model_name>`` format described in the env template; the
        helper's body is outside this view, so confirm against its definition.
        """
        return _validate_model_identifier(value, "llm_model")

    @field_validator("graph_query_level")
    @classmethod
    def _validate_graph_query_level(cls, value: GraphQueryLevel | str) -> GraphQueryLevel:
        """Normalize string input (trim + lowercase) and reject unknown levels.

        Non-string input is passed through untouched for pydantic to handle.
        """
        if not isinstance(value, str):
            return value
        normalized = value.strip().lower()
        if normalized in ("simple", "medium", "advanced"):
            return normalized  # type: ignore[return-value]
        msg = "graph_query_level must be one of: simple, medium, advanced"
        raise ValueError(msg)


__all__ = ["AiConfig"]
__all__ = ["AiConfig", "Backend", "GraphQueryLevel"]
+56 −0
Original line number Diff line number Diff line
@@ -44,6 +44,14 @@ class GraphNodeType(StrEnum):
    CONCEPT = "concept"


class GraphQueryLevel(StrEnum):
    """Level of sophistication for graph query answer generation."""

    SIMPLE = "simple"  # Return count and list without synthesis
    MEDIUM = "medium"  # Parse query, filter, generate simple text summary
    ADVANCED = "advanced"  # Use LLM to synthesize answer from graph + embeddings (GraphRAG)


class GraphEdgeType(StrEnum):
    """Types of edges in the knowledge graph."""

@@ -54,6 +62,34 @@ class GraphEdgeType(StrEnum):
    AUTHORED_BY = "authored_by"
    MERGED_INTO = "merged_into"
    PRESENTED_AT = "presented_at"
    REVISION_OF = "revision_of"  # is_revision_of metadata relationship


class ProcessingFailureType(StrEnum):
    """Classification of processing failures to determine retry behavior.

    Permanent failures - do NOT retry:
    - NOT_FOUND_ONLINE: Document withdrawn or never existed
    - DOWNLOAD_FAILED: Could not download source file
    - BROKEN_SOURCE: Downloaded file is corrupt/invalid
    - CLASSIFICATION_FAILED: Could not identify main document

    Retryable failures - CAN retry in next run:
    - EXTRACTION_FAILED: DOCX to Markdown conversion failed
    - EMBEDDING_FAILED: Embedding generation failed
    - GRAPH_FAILED: Graph building failed
    """

    # Permanent failures
    NOT_FOUND_ONLINE = "not_found_online"
    DOWNLOAD_FAILED = "download_failed"
    BROKEN_SOURCE = "broken_source"
    CLASSIFICATION_FAILED = "classification_failed"

    # Retryable failures
    EXTRACTION_FAILED = "extraction_failed"
    EMBEDDING_FAILED = "embedding_failed"
    GRAPH_FAILED = "graph_failed"


class WorkspaceStatus(StrEnum):
@@ -176,10 +212,30 @@ class ProcessingStatus(BaseModel):
    graphed_at: datetime | None = Field(None, description="Timestamp when graphing completed")
    completed_at: datetime | None = Field(None, description="Timestamp when pipeline completed")
    error_message: str | None = Field(None, description="Error details for failed stage")
    failure_type: ProcessingFailureType | None = Field(None, description="Type of failure if permanent")
    source_hash: str | None = Field(None, description="Hash of source DOCX for change detection")
    keywords: list[str] | None = Field(None, description="Keywords extracted from document content")
    detected_language: str | None = Field(None, description="Primary language detected in document")

    @property
    def is_permanent_failure(self) -> bool:
        """Return True when the recorded failure must not be retried.

        Permanent failures are those where a re-run cannot help: the document
        was not found online, could not be downloaded, is corrupt, or could
        not be classified.
        """
        # ``None`` can never be a member of the set, so the explicit
        # ``is not None`` guard from the verbose form is redundant.
        return self.failure_type in {
            ProcessingFailureType.NOT_FOUND_ONLINE,
            ProcessingFailureType.DOWNLOAD_FAILED,
            ProcessingFailureType.BROKEN_SOURCE,
            ProcessingFailureType.CLASSIFICATION_FAILED,
        }

    @property
    def is_retryable_failure(self) -> bool:
        """Return True when the recorded failure may succeed on a later run.

        Covers transient stage failures: extraction, embedding generation,
        and graph building.
        """
        # ``None`` can never be a member of the set, so the explicit
        # ``is not None`` guard from the verbose form is redundant.
        return self.failure_type in {
            ProcessingFailureType.EXTRACTION_FAILED,
            ProcessingFailureType.EMBEDDING_FAILED,
            ProcessingFailureType.GRAPH_FAILED,
        }

    @field_validator("document_id")
    @classmethod
    def _normalize_document_id(cls, value: str) -> str:
+3 −2
Original line number Diff line number Diff line
@@ -95,7 +95,7 @@ class AiStorage:
                unscoped_record = dict(record)
                unscoped_record["document_id"] = _from_scoped_document_id(scoped_document_id)[1]
                # Handle NaN/NaT values from LanceDB
                for field in ["error_message", "source_hash", "detected_language"]:
                for field in ["error_message", "source_hash", "detected_language", "failure_type"]:
                    if field in unscoped_record and (unscoped_record[field] is None or _is_nan(unscoped_record[field])):
                        unscoped_record[field] = None
                # Handle NaT for optional datetime fields
@@ -121,7 +121,7 @@ class AiStorage:
            unscoped_record = dict(record)
            unscoped_record["document_id"] = unscoped_document_id
            # Handle NaN/NaT values from LanceDB
            for field in ["error_message", "source_hash", "detected_language"]:
            for field in ["error_message", "source_hash", "detected_language", "failure_type"]:
                if field in unscoped_record and (unscoped_record[field] is None or _is_nan(unscoped_record[field])):
                    unscoped_record[field] = None
            # Handle NaT for optional datetime fields
@@ -511,6 +511,7 @@ def _processing_status_schema() -> pa.Schema:
            pa.field("graphed_at", pa.timestamp("us"), nullable=True),
            pa.field("completed_at", pa.timestamp("us"), nullable=True),
            pa.field("error_message", pa.string(), nullable=True),
            pa.field("failure_type", pa.string(), nullable=True),
            pa.field("source_hash", pa.string(), nullable=True),
            pa.field("keywords", pa.list_(pa.string()), nullable=True),
            pa.field("detected_language", pa.string(), nullable=True),