Source code for agenix.core.llm

"""Unified LLM interface supporting multiple providers."""

import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union

from .messages import (AssistantMessage, ImageContent, Message, TextContent,
                       ToolCall, ToolResultMessage, Usage, UserMessage)


[docs] @dataclass class StreamEvent: """LLM stream event.""" type: str delta: str = "" tool_call: Optional[ToolCall] = None finish_reason: Optional[str] = None # "stop", "length", "tool_calls", "content_filter"
[docs] class LLMProvider(ABC): """Abstract LLM provider interface."""
[docs] def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): self.api_key = api_key self.base_url = base_url
[docs] @abstractmethod async def stream( self, model: str, messages: List[Message], system_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, max_tokens: int = 4096, ) -> AsyncIterator[StreamEvent]: """Stream LLM responses.""" pass
[docs] @abstractmethod async def complete( self, model: str, messages: List[Message], system_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, max_tokens: int = 4096, ) -> AssistantMessage: """Complete LLM request (non-streaming).""" pass
def _messages_to_dict(self, messages: List[Message]) -> List[Dict[str, Any]]: """Convert messages to API format.""" result = [] for msg in messages: if isinstance(msg, UserMessage): result.append({ "role": "user", "content": msg.content if isinstance(msg.content, str) else self._format_content(msg.content) }) elif isinstance(msg, AssistantMessage): entry = {"role": "assistant"} if isinstance(msg.content, str): entry["content"] = msg.content else: # Extract text content text_parts = [ c.text for c in msg.content if isinstance(c, TextContent)] if text_parts: entry["content"] = "\n".join(text_parts) else: # No text content - set empty string to satisfy API requirements entry["content"] = "" # Add tool calls if present if msg.tool_calls: entry["tool_calls"] = [ { "id": tc.id, "type": "function", "function": { "name": tc.name, "arguments": json.dumps(tc.arguments) } } for tc in msg.tool_calls ] result.append(entry) elif isinstance(msg, ToolResultMessage): formatted_content = msg.content if isinstance(msg.content, list): formatted_content = self._format_content(msg.content) # If _format_content returns a list (contains images), we need to: # 1. First send a tool message with text only (to satisfy OpenAI's requirement) # 2. Then send a user message with images if isinstance(formatted_content, list): # Extract text parts for tool response text_parts = [item["text"] for item in formatted_content if item["type"] == "text"] tool_response = "\n".join(text_parts) if text_parts else "[Image content]" # Add tool message (required by OpenAI API) result.append({ "role": "tool", "tool_call_id": msg.tool_call_id, "content": tool_response }) # Add user message with images result.append({ "role": "user", "content": formatted_content }) continue result.append({ "role": "tool", "tool_call_id": msg.tool_call_id, "content": formatted_content if isinstance(formatted_content, str) else str(formatted_content) }) return result def _format_content(self, content: List[Any]) -> Union[str, List[Dict]]: """Format mixed content, preserving images.""" from .messages import TextContent, ImageContent # Check if content is text-only (more generic than has_images) is_text_only = all(isinstance(item, TextContent) for item in content) if is_text_only: # Text-only: return simple string for backward compatibility return "\n".join([item.text for item in content]) # Rich content: build list of content blocks result = [] for item in content: if isinstance(item, TextContent): result.append({ "type": "text", "text": item.text }) elif isinstance(item, ImageContent): # OpenAI format: data URL with base64 result.append({ "type": "image_url", "image_url": { "url": f"data:{item.source['media_type']};base64,{item.source['data']}" } }) else: # Fallback for unknown content types: convert to text # This allows future extensions without breaking result.append({ "type": "text", "text": str(item) if not hasattr(item, 'text') else item.text }) return result
[docs] class OpenAIProvider(LLMProvider): """OpenAI/compatible API provider."""
[docs] def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): super().__init__( api_key=api_key or os.getenv("OPENAI_API_KEY"), base_url=base_url or os.getenv( "OPENAI_BASE_URL") or "https://api.openai.com/v1" ) # Validate API key if not self.api_key: raise ValueError( "OpenAI API key not found. Please set OPENAI_API_KEY environment variable " "or pass api_key parameter." )
[docs] async def stream( self, model: str, messages: List[Message], system_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, max_tokens: int = 4096, ) -> AsyncIterator[StreamEvent]: """Stream OpenAI responses.""" try: import openai except ImportError: raise ImportError( "openai package required. Install with: pip install openai") try: client = openai.AsyncOpenAI( api_key=self.api_key, base_url=self.base_url) except Exception as e: raise ValueError(f"Failed to create OpenAI client: {e}") try: api_messages = [] if system_prompt: api_messages.append( {"role": "system", "content": system_prompt}) api_messages.extend(self._messages_to_dict(messages)) kwargs = { "model": model, "messages": api_messages, "max_tokens": max_tokens, "stream": True, } if tools: kwargs["tools"] = [self._convert_tool(t) for t in tools] stream = await client.chat.completions.create(**kwargs) # Accumulate tool calls during streaming tool_calls_accumulator = {} # index -> {id, name, arguments_str} finish_reason = None # Capture finish reason from last chunk async for chunk in stream: if not chunk.choices: continue choice = chunk.choices[0] delta = choice.delta # Capture finish_reason when available if choice.finish_reason: finish_reason = choice.finish_reason if delta.content: yield StreamEvent(type="text_delta", delta=delta.content) if delta.tool_calls: for tc in delta.tool_calls: # Tool calls come in increments, need to accumulate # Use index as key since id might not be in every delta tc_index = tc.index if tc.index is not None else 0 tc_id = tc.id or f"call_{tc_index}" if tc_index not in tool_calls_accumulator: tool_calls_accumulator[tc_index] = { "id": tc_id, "name": "", "arguments": "" } # Update id if provided if tc.id: tool_calls_accumulator[tc_index]["id"] = tc.id if tc.function: if tc.function.name: tool_calls_accumulator[tc_index]["name"] = tc.function.name if tc.function.arguments: tool_calls_accumulator[tc_index]["arguments"] += tc.function.arguments # After stream ends, yield complete tool calls for tc_data in tool_calls_accumulator.values(): # Parse arguments - no repair, let LLM retry on failure if not tc_data["arguments"] or not tc_data["arguments"].strip(): # Empty arguments arguments = {} else: try: arguments = json.loads(tc_data["arguments"]) except json.JSONDecodeError as e: # JSON parsing failed - return empty dict, LLM will see error and retry import sys args_preview = tc_data["arguments"][:100] if len(tc_data["arguments"]) > 100 else tc_data["arguments"] print(f"Warning: Invalid JSON for tool '{tc_data['name']}': {args_preview}", file=sys.stderr) print(f" Error: {e}", file=sys.stderr) arguments = {} yield StreamEvent( type="tool_call", tool_call=ToolCall( id=tc_data["id"], name=tc_data["name"], arguments=arguments ) ) # Yield finish_reason as final event if finish_reason: yield StreamEvent(type="finish", finish_reason=finish_reason) except Exception as e: # Re-raise to be handled by caller raise finally: # Close the client properly await client.close()
[docs] async def complete( self, model: str, messages: List[Message], system_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, max_tokens: int = 4096, ) -> AssistantMessage: """Complete OpenAI request.""" try: import openai except ImportError: raise ImportError( "openai package required. Install with: pip install openai") try: client = openai.AsyncOpenAI( api_key=self.api_key, base_url=self.base_url) except Exception as e: raise ValueError(f"Failed to create OpenAI client: {e}") try: api_messages = [] if system_prompt: api_messages.append( {"role": "system", "content": system_prompt}) api_messages.extend(self._messages_to_dict(messages)) kwargs = { "model": model, "messages": api_messages, "max_tokens": max_tokens, } if tools: kwargs["tools"] = [self._convert_tool(t) for t in tools] response = await client.chat.completions.create(**kwargs) choice = response.choices[0] # Build assistant message content_parts = [] tool_calls = [] if choice.message.content: content_parts.append(TextContent(text=choice.message.content)) if choice.message.tool_calls: for tc in choice.message.tool_calls: tool_calls.append(ToolCall( id=tc.id, name=tc.function.name, arguments=json.loads(tc.function.arguments) )) usage = None if response.usage: usage = Usage( input_tokens=response.usage.prompt_tokens, output_tokens=response.usage.completion_tokens, ) return AssistantMessage( content=content_parts, tool_calls=tool_calls, model=model, usage=usage, stop_reason=choice.finish_reason, ) finally: # Close the client properly await client.close()
def _convert_tool(self, tool: Dict[str, Any]) -> Dict[str, Any]: """Convert tool definition to OpenAI format.""" return { "type": "function", "function": { "name": tool["name"], "description": tool["description"], "parameters": tool.get("parameters", {}) } }
[docs] class AnthropicProvider(LLMProvider): """Anthropic Claude API provider."""
[docs] def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): super().__init__( api_key=api_key or os.getenv("ANTHROPIC_API_KEY"), base_url=base_url ) # Validate API key if not self.api_key: raise ValueError( "Anthropic API key not found. Please set ANTHROPIC_API_KEY environment variable " "or pass api_key parameter." )
[docs] async def stream( self, model: str, messages: List[Message], system_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, max_tokens: int = 4096, ) -> AsyncIterator[StreamEvent]: """Stream Anthropic responses.""" try: import anthropic except ImportError: raise ImportError( "anthropic package required. Install with: pip install anthropic") try: kwargs = {"api_key": self.api_key} if self.base_url: kwargs["base_url"] = self.base_url client = anthropic.AsyncAnthropic(**kwargs) except Exception as e: raise ValueError(f"Failed to create Anthropic client: {e}") kwargs = { "model": model, "messages": self._anthropic_messages(messages), "max_tokens": max_tokens, } if system_prompt: kwargs["system"] = system_prompt if tools: kwargs["tools"] = [self._convert_tool(t) for t in tools] async with client.messages.stream(**kwargs) as stream: # Track tool calls being built tool_calls_map = {} async for event in stream: if event.type == "content_block_start": if hasattr(event.content_block, "type") and event.content_block.type == "tool_use": # Tool call started - initialize tracking tool_calls_map[event.index] = { "id": event.content_block.id, "name": event.content_block.name, "input": "" } elif event.type == "content_block_delta": if hasattr(event.delta, "text"): # Text content yield StreamEvent(type="text_delta", delta=event.delta.text) elif hasattr(event.delta, "partial_json"): # Tool call input in progress if event.index in tool_calls_map: tool_calls_map[event.index]["input"] += event.delta.partial_json elif event.type == "content_block_stop": # Content block finished if event.index in tool_calls_map: # Tool call completed - parse JSON and yield import json tool_data = tool_calls_map[event.index] try: arguments = json.loads(tool_data["input"]) except json.JSONDecodeError: arguments = {} yield StreamEvent( type="tool_call", tool_call=ToolCall( id=tool_data["id"], name=tool_data["name"], arguments=arguments ) ) # Get final message to extract stop_reason final_message = await stream.get_final_message() if final_message and hasattr(final_message, 'stop_reason'): # Map Anthropic stop_reason to OpenAI-style # Anthropic: "end_turn", "max_tokens", "stop_sequence", "tool_use" # OpenAI: "stop", "length", "content_filter", "tool_calls" stop_reason_map = { "end_turn": "stop", "max_tokens": "length", "tool_use": "tool_calls", "stop_sequence": "stop" } finish_reason = stop_reason_map.get(final_message.stop_reason, final_message.stop_reason) yield StreamEvent(type="finish", finish_reason=finish_reason)
[docs] async def complete( self, model: str, messages: List[Message], system_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, max_tokens: int = 4096, ) -> AssistantMessage: """Complete Anthropic request.""" try: import anthropic except ImportError: raise ImportError( "anthropic package required. Install with: pip install anthropic") try: kwargs = {"api_key": self.api_key} if self.base_url: kwargs["base_url"] = self.base_url client = anthropic.AsyncAnthropic(**kwargs) except Exception as e: raise ValueError(f"Failed to create Anthropic client: {e}") kwargs = { "model": model, "messages": self._anthropic_messages(messages), "max_tokens": max_tokens, } if system_prompt: kwargs["system"] = system_prompt if tools: kwargs["tools"] = [self._convert_tool(t) for t in tools] response = await client.messages.create(**kwargs) # Parse response content_parts = [] tool_calls = [] for block in response.content: if block.type == "text": content_parts.append(TextContent(text=block.text)) elif block.type == "tool_use": tool_calls.append(ToolCall( id=block.id, name=block.name, arguments=block.input )) usage = None if response.usage: usage = Usage( input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens, ) return AssistantMessage( content=content_parts, tool_calls=tool_calls, model=model, usage=usage, stop_reason=response.stop_reason, )
def _anthropic_messages(self, messages: List[Message]) -> List[Dict[str, Any]]: """Convert messages to Anthropic format.""" result = [] for msg in messages: if isinstance(msg, UserMessage): result.append({ "role": "user", "content": msg.content if isinstance(msg.content, str) else self._format_content(msg.content) }) elif isinstance(msg, AssistantMessage): content = [] # Add text content if isinstance(msg.content, str): content.append({"type": "text", "text": msg.content}) else: for item in msg.content: if isinstance(item, TextContent): content.append({"type": "text", "text": item.text}) # Add tool calls for tc in msg.tool_calls: content.append({ "type": "tool_use", "id": tc.id, "name": tc.name, "input": tc.arguments }) result.append({"role": "assistant", "content": content}) elif isinstance(msg, ToolResultMessage): content = msg.content # Handle mixed content with images if isinstance(content, list): tool_result_content = [] for item in content: if isinstance(item, TextContent): tool_result_content.append({ "type": "text", "text": item.text }) elif isinstance(item, ImageContent): # Anthropic native image format tool_result_content.append({ "type": "image", "source": item.source # Already in correct format! }) result.append({ "role": "user", "content": [ { "type": "tool_result", "tool_use_id": msg.tool_call_id, "content": tool_result_content, # List of text/image blocks "is_error": msg.is_error } ] }) else: # Simple string content - existing logic result.append({ "role": "user", "content": [ { "type": "tool_result", "tool_use_id": msg.tool_call_id, "content": content, "is_error": msg.is_error } ] }) return result def _convert_tool(self, tool: Dict[str, Any]) -> Dict[str, Any]: """Convert tool definition to Anthropic format.""" return { "name": tool["name"], "description": tool["description"], "input_schema": tool.get("parameters", {}) }
# Registry of providers PROVIDERS = { "openai": OpenAIProvider, "anthropic": AnthropicProvider, }
[docs] def get_provider(provider_name: str, **kwargs) -> LLMProvider: """Get LLM provider by name.""" provider_class = PROVIDERS.get(provider_name) if not provider_class: raise ValueError(f"Unknown provider: {provider_name}") return provider_class(**kwargs)