Commit 34c31f00 authored by Jan Reimes's avatar Jan Reimes
Browse files

refactor(main): simplify entry point by removing asyncio.run

refactor(server): enhance server functionality with new search tools
* Add serialization and parsing functions for search responses
* Implement async search and term info retrieval tools

test(client): replace AsyncMock with Mock for raise_for_status
parent d8c75dc2
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
"""Entry point for running teddi_mcp as a module (python -m teddi_mcp)."""

import asyncio

from teddi_mcp.server import main

if __name__ == "__main__":
    asyncio.run(main())
    main()
+119 −3
Original line number Diff line number Diff line
"""FastMCP 3.0 server for TEDDI search."""

import logging
from typing import Any

from fastmcp import FastMCP

from teddi_mcp.client import TeddiClient
from teddi_mcp.models import SearchIn, SearchPattern, SearchRequest, SearchResponse, TechnicalBody

logger = logging.getLogger(__name__)

# Create FastMCP server
server = FastMCP("teddi-mcp")


def _serialize_response(request: SearchRequest, response: SearchResponse) -> dict[str, Any]:
    """Serialize a search response to a JSON-compatible payload."""
    return {
        "query": {
            "term": request.term,
            "search_in": request.search_in.value,
            "search_pattern": request.search_pattern.value,
            "technical_bodies": ([tb.value for tb in request.technical_bodies] if request.technical_bodies else None),
        },
        "total_count": response.total_count,
        "results": [
            {
                "term": result.term,
                "description": result.description,
                "documents": [
                    {
                        "technical_body": doc.technical_body,
                        "specification": doc.specification,
                        "url": doc.url,
                    }
                    for doc in result.documents
                ],
            }
            for result in response.results
        ],
    }


def _parse_technical_bodies(technical_bodies: str | None) -> list[TechnicalBody] | None:
    """Parse optional comma-separated technical body filter.

    Current implementation supports only the `all` value from the TechnicalBody enum.
    """
    if not technical_bodies:
        return None

    values = [v.strip().lower() for v in technical_bodies.split(",") if v.strip()]
    if not values:
        return None

    try:
        return [TechnicalBody(v) for v in values]
    except ValueError as exc:
        raise ValueError("Invalid technical_bodies value. Currently supported values: all") from exc


@server.tool
async def list_technical_bodies() -> dict[str, list[str]]:
    """List supported technical body filters for TEDDI search."""
    async with TeddiClient() as client:
        bodies = await client.get_available_technical_bodies()

    return {"technical_bodies": [tb.value for tb in bodies]}


@server.tool
async def search_teddi(
    term: str,
    search_in: str = SearchIn.BOTH.value,
    search_pattern: str = SearchPattern.ALL_OCCURRENCES.value,
    technical_bodies: str | None = None,
) -> dict[str, Any]:
    """Search TEDDI terms with optional scope and pattern filters.

    Args:
        term: Term or abbreviation to search.
        search_in: One of abbreviations, definitions, both.
        search_pattern: One of exactmatch, startingwith, endingwith, alloccurrences.
        technical_bodies: Optional comma-separated technical body filter.
    """
    try:
        request = SearchRequest(
            term=term,
            search_in=SearchIn(search_in.lower()),
            search_pattern=SearchPattern(search_pattern.lower()),
            technical_bodies=_parse_technical_bodies(technical_bodies),
        )
    except ValueError as exc:
        valid_search_in = ", ".join(v.value for v in SearchIn)
        valid_search_pattern = ", ".join(v.value for v in SearchPattern)
        raise ValueError(
            f"Invalid search arguments. search_in must be one of: {valid_search_in}; search_pattern must be one of: {valid_search_pattern}."
        ) from exc

    async with TeddiClient() as client:
        response = await client.search_terms(request)

    return _serialize_response(request, response)


@server.tool
async def get_term_info(
    term: str,
    technical_bodies: str | None = None,
) -> dict[str, Any]:
    """Get exact-match TEDDI entries for one term.

    This helper uses search_in=both and search_pattern=exactmatch.
    """
    request = SearchRequest(
        term=term,
        search_in=SearchIn.BOTH,
        search_pattern=SearchPattern.EXACT_MATCH,
        technical_bodies=_parse_technical_bodies(technical_bodies),
    )

    async with TeddiClient() as client:
        response = await client.search_terms(request)

    payload = _serialize_response(request, response)
    payload["exact_match"] = response.total_count > 0
    return payload


def main() -> None:
    """Main entry point for the MCP server."""
    logging.basicConfig(
@@ -23,6 +141,4 @@ def main() -> None:


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
    main()
+8 −5
Original line number Diff line number Diff line
"""Tests for TeddiClient."""

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch

import pytest

@@ -32,7 +32,7 @@ class TestTeddiClient:
        mock_client = AsyncMock()
        mock_response = AsyncMock()
        mock_response.text = mock_html
        mock_response.raise_for_status = AsyncMock()
        mock_response.raise_for_status = Mock()
        mock_client.post.return_value = mock_response

        client = TeddiClient(client=mock_client)
@@ -60,7 +60,7 @@ class TestTeddiClient:
        mock_client = AsyncMock()
        mock_response = AsyncMock()
        mock_response.text = mock_html
        mock_response.raise_for_status = AsyncMock()
        mock_response.raise_for_status = Mock()
        mock_client.post.return_value = mock_response

        client = TeddiClient(client=mock_client)
@@ -93,6 +93,7 @@ class TestTeddiClient:
        mock_client = AsyncMock()
        mock_response = AsyncMock()
        mock_response.text = mock_html
        mock_response.raise_for_status = Mock()
        mock_client.post.return_value = mock_response

        client = TeddiClient(client=mock_client)
@@ -117,7 +118,9 @@ class TestTeddiClient:
        """Test using TeddiClient as async context manager."""
        with patch("src.teddi_mcp.client.create_cached_teddi_async_client") as mock_create:
            mock_client = AsyncMock()
            mock_client.post.return_value = AsyncMock(text="<table></table>")
            mock_response = AsyncMock(text="<table></table>")
            mock_response.raise_for_status = Mock()
            mock_client.post.return_value = mock_response
            mock_create.return_value = mock_client

            async with TeddiClient() as client:
@@ -132,7 +135,7 @@ class TestTeddiClient:
        mock_client = AsyncMock()
        mock_response = AsyncMock()
        mock_response.content = b"PDF content"
        mock_response.raise_for_status = AsyncMock()
        mock_response.raise_for_status = Mock()
        mock_client.get.return_value = mock_response

        client = TeddiClient(client=mock_client)