# Source code for agenix.core.llm

"""Unified LLM interface supporting multiple providers."""

import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Dict, List, Optional

from .messages import (AssistantMessage, Message, TextContent, ToolCall,
                       ToolResultMessage, Usage, UserMessage)


@dataclass
class StreamEvent:
    """One incremental event produced while streaming an LLM response.

    Attributes:
        type: Event kind. Providers in this module emit ``"text_delta"``
            (new text is in ``delta``) and ``"tool_call"`` (a completed
            call is in ``tool_call``).
        delta: Text fragment for ``"text_delta"`` events; ``""`` by default.
        tool_call: Completed tool call for ``"tool_call"`` events;
            ``None`` by default.
    """

    type: str
    delta: str = ""
    tool_call: Optional[ToolCall] = None
class LLMProvider(ABC):
    """Abstract LLM provider interface.

    Subclasses implement :meth:`stream` and :meth:`complete` for a specific
    vendor API. The base class stores connection settings and provides
    helpers that convert ``Message`` objects into OpenAI-style chat
    message dictionaries.
    """

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Store connection settings.

        Args:
            api_key: Credential for the vendor API; may be ``None`` here,
                subclasses decide how to resolve or validate it.
            base_url: Optional API endpoint override.
        """
        self.api_key = api_key
        self.base_url = base_url

    @abstractmethod
    async def stream(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream LLM responses.

        Args:
            model: Vendor model identifier.
            messages: Conversation history.
            system_prompt: Optional system instruction.
            tools: Optional tool definitions (dicts with ``name``,
                ``description`` and optionally ``parameters``).
            max_tokens: Maximum number of tokens to generate.

        Yields:
            ``StreamEvent`` objects as the response is produced.
        """
        pass

    @abstractmethod
    async def complete(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AssistantMessage:
        """Complete LLM request (non-streaming).

        Takes the same arguments as :meth:`stream` and returns the whole
        response as a single ``AssistantMessage``.
        """
        pass

    def _messages_to_dict(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI-style chat API dictionaries.

        * ``UserMessage`` -> ``{"role": "user", "content": str}``
        * ``AssistantMessage`` -> ``{"role": "assistant", ...}`` with
          ``tool_calls`` whenever the message carries any, regardless of
          whether ``content`` is a string or a list of parts.
        * ``ToolResultMessage`` -> ``{"role": "tool", "tool_call_id", "content"}``

        Messages of any other type are skipped.
        """
        result: List[Dict[str, Any]] = []
        for msg in messages:
            if isinstance(msg, UserMessage):
                result.append({
                    "role": "user",
                    "content": msg.content if isinstance(msg.content, str)
                    else self._format_content(msg.content),
                })
            elif isinstance(msg, AssistantMessage):
                entry: Dict[str, Any] = {"role": "assistant"}

                # Serialize tool calls independently of the content shape so
                # they are not lost when ``content`` is a plain string.
                if msg.tool_calls:
                    entry["tool_calls"] = [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.name,
                                "arguments": json.dumps(tc.arguments),
                            },
                        }
                        for tc in msg.tool_calls
                    ]

                if isinstance(msg.content, str):
                    entry["content"] = msg.content
                else:
                    text_parts = [
                        c.text for c in msg.content if isinstance(c, TextContent)
                    ]
                    if text_parts:
                        entry["content"] = "\n".join(text_parts)
                    elif not msg.tool_calls:
                        # No text and no tool calls: send an explicit empty
                        # string rather than omitting ``content``.
                        entry["content"] = ""
                result.append(entry)
            elif isinstance(msg, ToolResultMessage):
                result.append({
                    "role": "tool",
                    "tool_call_id": msg.tool_call_id,
                    "content": msg.content if isinstance(msg.content, str)
                    else self._format_content(msg.content),
                })
        return result

    def _format_content(self, content: List[Any]) -> str:
        """Join the ``text`` of every content item that has one, newline-separated.

        Items without a ``text`` attribute are skipped.
        """
        return "\n".join(item.text for item in content if hasattr(item, "text"))
class OpenAIProvider(LLMProvider):
    """OpenAI/compatible API provider."""

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Resolve credentials and endpoint.

        ``api_key`` falls back to ``OPENAI_API_KEY``; ``base_url`` falls back
        to ``OPENAI_BASE_URL`` and then ``https://api.openai.com/v1``.

        Raises:
            ValueError: If no API key can be resolved.
        """
        super().__init__(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url or os.getenv(
                "OPENAI_BASE_URL") or "https://api.openai.com/v1"
        )
        # Validate API key
        if not self.api_key:
            raise ValueError(
                "OpenAI API key not found. Please set OPENAI_API_KEY environment variable "
                "or pass api_key parameter."
            )

    def _create_client(self) -> Any:
        """Import the ``openai`` SDK lazily and build an ``AsyncOpenAI`` client.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        try:
            import openai
        except ImportError as err:
            raise ImportError(
                "openai package required. Install with: pip install openai") from err
        try:
            return openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
        except Exception as e:
            raise ValueError(f"Failed to create OpenAI client: {e}") from e

    def _build_request(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str],
        tools: Optional[List[Dict[str, Any]]],
        max_tokens: int,
    ) -> Dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``."""
        api_messages: List[Dict[str, Any]] = []
        if system_prompt:
            api_messages.append({"role": "system", "content": system_prompt})
        api_messages.extend(self._messages_to_dict(messages))

        kwargs: Dict[str, Any] = {
            "model": model,
            "messages": api_messages,
            "max_tokens": max_tokens,
        }
        if tools:
            kwargs["tools"] = [self._convert_tool(t) for t in tools]
        return kwargs

    @staticmethod
    def _parse_tool_arguments(raw: Optional[str]) -> Any:
        """Decode a tool call's JSON argument string.

        Returns ``{}`` for empty or malformed JSON so that one bad tool call
        does not abort the whole response. Used by both :meth:`stream` and
        :meth:`complete` so they treat bad arguments the same way.
        """
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}

    async def stream(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream OpenAI responses.

        Text is yielded as ``"text_delta"`` events as it arrives. Tool calls
        arrive as fragments keyed by ``index``. They are accumulated and
        yielded as complete ``"tool_call"`` events after the stream ends.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            kwargs = self._build_request(
                model, messages, system_prompt, tools, max_tokens)
            kwargs["stream"] = True
            response_stream = await client.chat.completions.create(**kwargs)

            # index -> {"id", "name", "arguments"}; ``id`` may be absent from
            # later fragments, so the index is the stable key.
            pending: Dict[int, Dict[str, str]] = {}
            async for chunk in response_stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    yield StreamEvent(type="text_delta", delta=delta.content)

                for tc in delta.tool_calls or []:
                    idx = tc.index if tc.index is not None else 0
                    if idx not in pending:
                        pending[idx] = {
                            "id": tc.id or f"call_{idx}",
                            "name": "",
                            "arguments": "",
                        }
                    slot = pending[idx]
                    if tc.id:
                        slot["id"] = tc.id
                    if tc.function:
                        if tc.function.name:
                            slot["name"] = tc.function.name
                        if tc.function.arguments:
                            slot["arguments"] += tc.function.arguments

            # After the stream ends, emit the fully assembled tool calls.
            for slot in pending.values():
                yield StreamEvent(
                    type="tool_call",
                    tool_call=ToolCall(
                        id=slot["id"],
                        name=slot["name"],
                        arguments=self._parse_tool_arguments(slot["arguments"]),
                    ),
                )
        finally:
            await client.close()

    async def complete(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AssistantMessage:
        """Complete OpenAI request.

        Returns:
            An ``AssistantMessage`` with the text content, any tool calls
            (malformed argument JSON decoded as ``{}``), token usage when
            reported, and the choice's ``finish_reason`` as ``stop_reason``.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            kwargs = self._build_request(
                model, messages, system_prompt, tools, max_tokens)
            response = await client.chat.completions.create(**kwargs)
            choice = response.choices[0]

            content_parts = []
            if choice.message.content:
                content_parts.append(TextContent(text=choice.message.content))

            tool_calls = [
                ToolCall(
                    id=tc.id,
                    name=tc.function.name,
                    arguments=self._parse_tool_arguments(tc.function.arguments),
                )
                for tc in choice.message.tool_calls or []
            ]

            usage = None
            if response.usage:
                usage = Usage(
                    input_tokens=response.usage.prompt_tokens,
                    output_tokens=response.usage.completion_tokens,
                )

            return AssistantMessage(
                content=content_parts,
                tool_calls=tool_calls,
                model=model,
                usage=usage,
                stop_reason=choice.finish_reason,
            )
        finally:
            await client.close()

    def _convert_tool(self, tool: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a tool definition to OpenAI's ``function`` tool format."""
        return {
            "type": "function",
            "function": {
                "name": tool["name"],
                "description": tool["description"],
                "parameters": tool.get("parameters", {}),
            },
        }
class AnthropicProvider(LLMProvider):
    """Anthropic Claude API provider."""

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Resolve credentials; ``api_key`` falls back to ``ANTHROPIC_API_KEY``.

        Raises:
            ValueError: If no API key can be resolved.
        """
        super().__init__(
            api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
            base_url=base_url
        )
        # Validate API key
        if not self.api_key:
            raise ValueError(
                "Anthropic API key not found. Please set ANTHROPIC_API_KEY environment variable "
                "or pass api_key parameter."
            )

    def _create_client(self) -> Any:
        """Import the ``anthropic`` SDK lazily and build an ``AsyncAnthropic`` client.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        try:
            import anthropic
        except ImportError as err:
            raise ImportError(
                "anthropic package required. Install with: pip install anthropic") from err
        try:
            client_kwargs: Dict[str, Any] = {"api_key": self.api_key}
            if self.base_url:
                client_kwargs["base_url"] = self.base_url
            return anthropic.AsyncAnthropic(**client_kwargs)
        except Exception as e:
            raise ValueError(f"Failed to create Anthropic client: {e}") from e

    def _build_request(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str],
        tools: Optional[List[Dict[str, Any]]],
        max_tokens: int,
    ) -> Dict[str, Any]:
        """Assemble keyword arguments shared by ``messages.create`` and ``messages.stream``."""
        kwargs: Dict[str, Any] = {
            "model": model,
            "messages": self._anthropic_messages(messages),
            "max_tokens": max_tokens,
        }
        if system_prompt:
            kwargs["system"] = system_prompt
        if tools:
            kwargs["tools"] = [self._convert_tool(t) for t in tools]
        return kwargs

    async def stream(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream Anthropic responses.

        Text deltas are yielded as ``"text_delta"`` events as they arrive.
        ``tool_use`` blocks are yielded as ``"tool_call"`` events after the
        stream finishes. They are taken from the SDK's assembled final
        message, so the tool input JSON is complete.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            # ``messages.stream()`` turns streaming on itself and does not
            # accept a ``stream`` keyword argument.
            kwargs = self._build_request(
                model, messages, system_prompt, tools, max_tokens)
            async with client.messages.stream(**kwargs) as response_stream:
                async for event in response_stream:
                    if event.type == "content_block_delta" and hasattr(event.delta, "text"):
                        yield StreamEvent(type="text_delta", delta=event.delta.text)
                final_message = await response_stream.get_final_message()

            for block in final_message.content:
                if block.type == "tool_use":
                    yield StreamEvent(
                        type="tool_call",
                        tool_call=ToolCall(
                            id=block.id,
                            name=block.name,
                            arguments=block.input,
                        ),
                    )
        finally:
            await client.close()

    async def complete(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AssistantMessage:
        """Complete Anthropic request.

        Returns:
            An ``AssistantMessage`` built from the response's ``text`` and
            ``tool_use`` blocks, token usage when reported, and the
            response ``stop_reason``.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            kwargs = self._build_request(
                model, messages, system_prompt, tools, max_tokens)
            response = await client.messages.create(**kwargs)
        finally:
            await client.close()

        content_parts = []
        tool_calls = []
        for block in response.content:
            if block.type == "text":
                content_parts.append(TextContent(text=block.text))
            elif block.type == "tool_use":
                tool_calls.append(ToolCall(
                    id=block.id,
                    name=block.name,
                    arguments=block.input,
                ))

        usage = None
        if response.usage:
            usage = Usage(
                input_tokens=response.usage.input_tokens,
                output_tokens=response.usage.output_tokens,
            )

        return AssistantMessage(
            content=content_parts,
            tool_calls=tool_calls,
            model=model,
            usage=usage,
            stop_reason=response.stop_reason,
        )

    def _anthropic_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert messages to Anthropic Messages API format.

        * ``UserMessage`` -> user turn with string content.
        * ``AssistantMessage`` -> assistant turn with ``text`` and
          ``tool_use`` blocks (empty text is omitted).
        * ``ToolResultMessage`` -> user turn holding one ``tool_result`` block.

        Messages of any other type are skipped.
        """
        result: List[Dict[str, Any]] = []
        for msg in messages:
            if isinstance(msg, UserMessage):
                result.append({
                    "role": "user",
                    "content": msg.content if isinstance(msg.content, str)
                    else self._format_content(msg.content),
                })
            elif isinstance(msg, AssistantMessage):
                if isinstance(msg.content, str):
                    texts = [msg.content]
                else:
                    texts = [item.text for item in msg.content
                             if isinstance(item, TextContent)]
                # The API rejects empty text blocks, so drop them.
                content: List[Dict[str, Any]] = [
                    {"type": "text", "text": text} for text in texts if text
                ]
                for tc in msg.tool_calls or []:
                    content.append({
                        "type": "tool_use",
                        "id": tc.id,
                        "name": tc.name,
                        "input": tc.arguments,
                    })
                result.append({"role": "assistant", "content": content})
            elif isinstance(msg, ToolResultMessage):
                result.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "tool_result",
                            "tool_use_id": msg.tool_call_id,
                            "content": msg.content if isinstance(msg.content, str)
                            else self._format_content(msg.content),
                            "is_error": msg.is_error,
                        }
                    ],
                })
        return result

    def _convert_tool(self, tool: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a tool definition to Anthropic's tool format."""
        return {
            "name": tool["name"],
            "description": tool["description"],
            # ``input_schema`` must be an object-typed JSON Schema; tools
            # without parameters get an empty object schema.
            "input_schema": tool.get("parameters") or {"type": "object", "properties": {}},
        }
# Registry of providers: maps the name accepted by ``get_provider`` to the
# provider class that implements it.
PROVIDERS = dict(
    openai=OpenAIProvider,
    anthropic=AnthropicProvider,
)
def get_provider(provider_name: str, **kwargs) -> LLMProvider:
    """Instantiate the LLM provider registered under ``provider_name``.

    Args:
        provider_name: Key in ``PROVIDERS`` (e.g. ``"openai"``).
        **kwargs: Forwarded unchanged to the provider's constructor.

    Raises:
        ValueError: If no provider is registered under that name.
    """
    try:
        provider_class = PROVIDERS[provider_name]
    except KeyError:
        raise ValueError(f"Unknown provider: {provider_name}") from None
    return provider_class(**kwargs)