Commit 599fb5e7 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(cli): add cache directory option for CLI commands

* Introduced CacheDirOption to specify cache directory in CLI.
* Updated source, title, and agenda options to accept lists for repeatable patterns.
* Removed unused embedding dimension constants from ProviderConfig.
* Enhanced embedding dimension detection in TDocRAG for better flexibility.
parent 257b015a
Loading
Loading
Loading
Loading
+16 −12
Original line number Diff line number Diff line
@@ -9,6 +9,10 @@ import typer

# Common
JsonOutputOption = Annotated[bool, typer.Option("--json", help="Output as JSON")]
# Optional cache directory override shared by CLI commands (--cache-dir / -c).
# The option is also wired to the TDC_CACHE_DIR environment variable via `envvar`.
CacheDirOption = Annotated[
    Path | None,
    typer.Option("--cache-dir", "-c", help="Cache directory", envvar="TDC_CACHE_DIR"),
]

# Summarize
SummarizeDocumentArgument = Annotated[str, typer.Argument(help="Document ID to summarize")]
@@ -59,26 +63,26 @@ EndDateOption = Annotated[
    typer.Option("--end-date", help="Filter until ISO timestamp (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS)"),
]
SourcePatternOption = Annotated[
    str | None,
    typer.Option("--source", help="Glob/regex pattern for source field"),
    list[str] | None,
    typer.Option("--source", help="Glob pattern for source field (repeatable, values are OR'd)"),
]
SourcePatternExcludeOption = Annotated[
    str | None,
    typer.Option("--source-ex", help="Pattern to exclude source field"),
    list[str] | None,
    typer.Option("--source-ex", help="Glob pattern to exclude source field (repeatable)"),
]
TitlePatternOption = Annotated[
    str | None,
    typer.Option("--title", help="Glob/regex pattern for title field"),
    list[str] | None,
    typer.Option("--title", help="Glob pattern for title field (repeatable, values are OR'd)"),
]
TitlePatternExcludeOption = Annotated[
    str | None,
    typer.Option("--title-ex", help="Pattern to exclude title field"),
    list[str] | None,
    typer.Option("--title-ex", help="Glob pattern to exclude title field (repeatable)"),
]
AgendaPatternOption = Annotated[
    str | None,
    typer.Option("--agenda", help="Glob/regex pattern for agenda field"),
    list[str] | None,
    typer.Option("--agenda", help="Glob pattern for agenda field (repeatable, values are OR'd)"),
]
AgendaPatternExcludeOption = Annotated[
    str | None,
    typer.Option("--agenda-ex", help="Pattern to exclude agenda field"),
    list[str] | None,
    typer.Option("--agenda-ex", help="Glob pattern to exclude agenda field (repeatable)"),
]
+14 −7
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ import typer
from rich.progress import Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn
from rich.table import Table

from tdoc_crawler.config import resolve_cache_manager
from tdoc_crawler.config import CacheManager, resolve_cache_manager
from tdoc_crawler.database import SpecDatabase, TDocDatabase
from tdoc_crawler.logging import get_console, get_logger
from tdoc_crawler.models.base import OutputFormat, SortOrder
@@ -46,6 +46,7 @@ from threegpp_ai import (
from threegpp_ai.args import (
    AgendaPatternExcludeOption,
    AgendaPatternOption,
    CacheDirOption,
    ConvertDocumentArgument,
    ConvertForceOption,
    ConvertOutputOption,
@@ -87,6 +88,12 @@ console = get_console()
_logger = get_logger(__name__)


@app.callback()
def _app_init(cache_dir: CacheDirOption = None) -> None:
    """Register a CacheManager so all sub-commands can resolve file paths.

    Installed as the Typer application callback via ``@app.callback()``.

    Args:
        cache_dir: Cache directory handed to ``CacheManager``; defaults to
            ``None`` when no value is provided.
    """
    # NOTE(review): force=True presumably replaces any previously registered
    # manager — confirm against CacheManager.register.
    CacheManager(cache_dir).register(force=True)


def _resolve_workspace_name(workspace: str | None) -> str:
    if workspace:
        return normalize_workspace_name(workspace)
@@ -106,12 +113,12 @@ def _resolve_workspace_items(
    source_kind: SourceKind,
    start_date: str | None,
    end_date: str | None,
    source: str | None,
    source_ex: str | None,
    title: str | None,
    title_ex: str | None,
    agenda: str | None,
    agenda_ex: str | None,
    source: list[str] | None,
    source_ex: list[str] | None,
    title: list[str] | None,
    title_ex: list[str] | None,
    agenda: list[str] | None,
    agenda_ex: list[str] | None,
    limit: int | None,
) -> list[str]:
    if items is not None:
+5 −21
Original line number Diff line number Diff line
@@ -40,7 +40,6 @@ class ProviderConfig:

    complete_func: Callable | None = None
    embed_func: Callable | None = None
    default_dim: int = 1024


PROVIDERS: dict[str, ProviderConfig] = {
@@ -52,15 +51,6 @@ PROVIDERS: dict[str, ProviderConfig] = {
}

PROVIDER_ALIASES = {"zai": "zhipu"}
EMBEDDING_DIMENSIONS = {
    "qwen3-embedding:0.6b": 1024,
    "nomic-embed-text": 768,
    "mxbai-embed-large": 1024,
    "all-MiniLM-L6-v2": 384,
    "all-mpnet-base-v2": 768,
}


def _resolve_provider(provider: str) -> str:
    """Resolve provider alias to canonical name."""
    resolved = PROVIDER_ALIASES.get(provider, provider)
@@ -83,15 +73,6 @@ def _get_provider(provider: str) -> ProviderConfig:
    return PROVIDERS[_resolve_provider(provider)]


def _get_embedding_dimension(model_name: str, provider: str = "") -> int:
    """Get embedding dimension for a model."""
    if model_name in EMBEDDING_DIMENSIONS:
        return EMBEDDING_DIMENSIONS[model_name]
    if provider:
        return _get_provider(provider).default_dim
    return 1024


class TDocRAG:
    """Thin wrapper around LightRAG with Ollama and optional pg0 support.

@@ -189,14 +170,17 @@ class TDocRAG:
        embed_config = _get_provider(embed_provider)
        llm_config = _get_provider(llm_provider)

        embed_dim = _get_embedding_dimension(embed_model_name, embed_provider)

        embed_kwargs = self._build_provider_kwargs(embed_provider, embed_model_name, is_embedding=True)
        llm_kwargs = self._build_provider_kwargs(llm_provider, llm_model_name)

        async def wrapped_embed_func(texts: list[str], **kwargs: Any) -> list[list[float]]:
            return await embed_config.embed_func(texts, **embed_kwargs, **kwargs)  # type: ignore[call-arg]

        # Probe the model to detect the actual embedding dimension — never assume.
        probe = await wrapped_embed_func(["probe"])
        embed_dim = len(probe[0])
        logger.info("Detected embedding dimension: %d for model '%s'", embed_dim, self.config.embedding.model)

        embedding_func = EmbeddingFunc(embedding_dim=embed_dim, func=wrapped_embed_func)

        self._rag = LightRAG(