Commit d8b7cf81 authored by Jan Reimes's avatar Jan Reimes
Browse files

feat(ai): integrate litellm for enhanced LLM functionality

* Added litellm dependency for improved LLM operations.
* Updated AiConfig to utilize litellm's supported providers.
* Refactored LiteLLMClient to initialize without loading client dynamically.
* Simplified embedding model loading in EmbeddingModelWrapper.
* Cleaned up unused code and imports in summarize.py and embeddings.py.
parent e6997dba
Loading
Loading
Loading
Loading
+10 −28
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ from __future__ import annotations
import os
from pathlib import Path

import litellm
from pydantic import Field, field_validator, model_validator

from tdoc_crawler.models import BaseConfigModel
@@ -12,31 +13,6 @@ from tdoc_crawler.models import BaseConfigModel
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_LLM_MODEL = "openrouter/openrouter/free"

LITELLM_PROVIDER_ALLOWLIST = {
    "openai",
    "anthropic",
    "azure",
    "google",
    "vertex_ai",
    "cohere",
    "huggingface",
    "ollama",
    "groq",
    "mistral",
    "bedrock",
    "replicate",
    "together_ai",
    # Cloud platforms and aggregators
    "openrouter",  # Multi-provider aggregator
    "github_copilot",  # GitHub Copilot (Azure-based)
    "nvidia",  # NVIDIA NIM
    "sambanova",  # SambaNova Cloud
    "fireworks_ai",  # Fireworks AI
    "anyscale",  # Anyscale Endpoints
    "perplexity",  # Perplexity API
    "deepinfra",  # DeepInfra
}


def _env_int(name: str) -> int | None:
    value = os.getenv(name)
@@ -60,8 +36,14 @@ def _validate_model_identifier(value: str, field_name: str) -> str:
    if not model_name_normalized:
        msg = f"{field_name} model_name segment cannot be empty"
        raise ValueError(msg)
    if provider_normalized not in LITELLM_PROVIDER_ALLOWLIST:
        msg = f"{field_name} provider '{provider}' is not in supported provider allowlist: {sorted(LITELLM_PROVIDER_ALLOWLIST)}"

    supported_providers = litellm.LITELLM_CHAT_PROVIDERS

    if provider_normalized not in supported_providers:
        msg = (
            f"{field_name} provider '{provider}' is not supported by litellm. "
            f"See https://docs.litellm.ai/docs/providers for the full list of {len(supported_providers)} supported providers."
        )
        raise ValueError(msg)

    return f"{provider_normalized}/{model_name_normalized}"
@@ -184,4 +166,4 @@ class AiConfig(BaseConfigModel):
        return _validate_model_identifier(value, "llm_model")


__all__ = ["LITELLM_PROVIDER_ALLOWLIST", "AiConfig"]
__all__ = ["AiConfig"]
+4 −15
Original line number Diff line number Diff line
@@ -3,13 +3,14 @@
from __future__ import annotations

import hashlib
import importlib
import logging
import re
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast

from sentence_transformers import SentenceTransformer

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import DocumentChunk
from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
@@ -61,18 +62,12 @@ class EmbeddingModelWrapper:
        Returns:
            List of embedding vectors.
        """
        if self._model is None:
            msg = "sentence-transformers not installed"
            raise RuntimeError(msg)

        embeddings = self._model.encode(texts, normalize_embeddings=True)
        return embeddings.tolist()

    @property
    def dimension(self) -> int:
        """Return embedding dimension."""
        if self._model is None:
            return DEFAULT_DIMENSION
        # Get dimension from the actual model
        try:
            return self._model.get_sentence_embedding_dimension()
@@ -86,14 +81,8 @@ class EmbeddingModelWrapper:

    def _load_model(self) -> None:
        """Load the sentence-transformers model."""
        try:
            sentence_transformers = importlib.import_module("sentence_transformers")
            sentence_transformer_class = sentence_transformers.SentenceTransformer
            self._model = sentence_transformer_class(DEFAULT_MODEL)
        self._model = SentenceTransformer(DEFAULT_MODEL)
        logger.info(f"Loaded embedding model: {DEFAULT_MODEL}")
        except ImportError:
            logger.warning("sentence-transformers not installed")
            self._model = None


def _chunk_by_headings(markdown: str) -> list[dict[str, str]]:
+4 −17
Original line number Diff line number Diff line
@@ -3,12 +3,13 @@
from __future__ import annotations

import hashlib
import importlib
import json
import logging
import os
import re

import litellm

from tdoc_crawler.ai.config import AiConfig
from tdoc_crawler.ai.models import DocumentSummary, LlmConfigError, SummarizeResult
from tdoc_crawler.ai.operations.workspaces import normalize_workspace_name
@@ -92,8 +93,7 @@ class LiteLLMClient:
        return cls._instance

    def __init__(self) -> None:
        self._client = None
        self._load_client()
        logger.info("LiteLLM client initialized")

    def complete(
        self,
@@ -113,15 +113,11 @@ class LiteLLMClient:
        Returns:
            Generated text.
        """
        if self._client is None:
            msg = "litellm not installed"
            raise RuntimeError(msg)

        try:
            # Check for TDC_AI_LLM_API_KEY - takes precedence over provider-specific keys
            api_key = os.environ.get("TDC_AI_LLM_API_KEY")

            response = self._client.completion(
            response = litellm.completion(
                model=model or AiConfig().llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
@@ -135,15 +131,6 @@ class LiteLLMClient:
            logger.error(f"LLM completion failed: {e}")
            raise

    def _load_client(self) -> None:
        """Load the litellm client."""
        try:
            self._client = importlib.import_module("litellm")
            logger.info("Loaded litellm client")
        except ImportError:
            logger.warning("litellm not installed")
            self._client = None


def _should_skip_summary(
    tdoc_id: str,
+0 −1
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@
from pathlib import Path

from tdoc_crawler.database.specs import SpecDatabase
from tdoc_crawler.specs.models import Specification, SpecificationDownload, SpecificationSourceRecord, SpecificationVersion
from tdoc_crawler.specs.models import Specification, SpecificationDownload, SpecificationSourceRecord, SpecificationVersion, SpecQueryFilters