"""Unified LLM interface supporting multiple providers."""
import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Dict, List, Optional
from .messages import (AssistantMessage, Message, TextContent, ToolCall,
ToolResultMessage, Usage, UserMessage)
@dataclass()
class StreamEvent:
    """One incremental event produced while streaming an LLM response.

    Which optional field is meaningful depends on ``type``: the providers in
    this module emit ``"text_delta"`` events (text in ``delta``) and
    ``"tool_call"`` events (a completed call in ``tool_call``).
    """

    # Kind of event, e.g. "text_delta" or "tool_call".
    type: str
    # Text fragment carried by "text_delta" events; "" when unused.
    delta: str = ""
    # Fully assembled tool call carried by "tool_call" events; None when unused.
    tool_call: Optional[ToolCall] = None
class LLMProvider(ABC):
    """Abstract LLM provider interface.

    Concrete providers implement :meth:`stream` and :meth:`complete`. This base
    class stores the connection settings and provides helpers that convert the
    package's message objects into OpenAI-style chat-message dictionaries.
    """

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Store connection settings for the concrete client.

        Args:
            api_key: API credential, or None.
            base_url: Custom API endpoint, or None.
        """
        self.api_key = api_key
        self.base_url = base_url

    @abstractmethod
    async def stream(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream LLM responses as :class:`StreamEvent` objects.

        Args:
            model: Provider-specific model identifier.
            messages: Conversation history to send.
            system_prompt: Optional system instruction.
            tools: Optional tool definitions; the bundled providers read the
                ``name``, ``description`` and optional ``parameters`` keys.
            max_tokens: Maximum number of tokens to generate.
        """
        pass

    @abstractmethod
    async def complete(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AssistantMessage:
        """Complete an LLM request (non-streaming).

        Takes the same arguments as :meth:`stream` and returns the whole reply
        as a single :class:`AssistantMessage`.
        """
        pass

    def _messages_to_dict(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert messages to OpenAI chat-completion format.

        ``UserMessage`` becomes a ``"user"`` entry, ``AssistantMessage`` an
        ``"assistant"`` entry (text plus any ``tool_calls``), and
        ``ToolResultMessage`` a ``"tool"`` entry. Messages of any other type
        are skipped.
        """
        result: List[Dict[str, Any]] = []
        for msg in messages:
            if isinstance(msg, UserMessage):
                result.append({
                    "role": "user",
                    "content": self._content_to_text(msg.content),
                })
            elif isinstance(msg, AssistantMessage):
                result.append(self._assistant_to_dict(msg))
            elif isinstance(msg, ToolResultMessage):
                result.append({
                    "role": "tool",
                    "tool_call_id": msg.tool_call_id,
                    "content": self._content_to_text(msg.content),
                })
        return result

    def _assistant_to_dict(self, msg: AssistantMessage) -> Dict[str, Any]:
        """Convert one assistant message to an OpenAI-style dict.

        Tool calls are included whether ``msg.content`` is a plain string or a
        list of content items. ``content`` is omitted only when the message has
        no text items at all but does carry tool calls.
        """
        entry: Dict[str, Any] = {"role": "assistant"}
        if isinstance(msg.content, str):
            text_parts = [msg.content]
        else:
            text_parts = [
                c.text for c in (msg.content or []) if isinstance(c, TextContent)
            ]
        tool_calls = msg.tool_calls or []
        if tool_calls:
            entry["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.name,
                        # OpenAI expects arguments as a JSON-encoded string.
                        "arguments": json.dumps(tc.arguments),
                    },
                }
                for tc in tool_calls
            ]
        if text_parts or not tool_calls:
            entry["content"] = "\n".join(text_parts)
        return entry

    def _content_to_text(self, content: Any) -> str:
        """Return ``content`` unchanged if it is a string, else flatten it to text."""
        return content if isinstance(content, str) else self._format_content(content)

    def _format_content(self, content: List[Any]) -> str:
        """Join the ``text`` of every item that has one, newline-separated.

        Items without a ``text`` attribute are skipped.
        """
        return "\n".join(item.text for item in content if hasattr(item, "text"))
class OpenAIProvider(LLMProvider):
    """OpenAI/compatible API provider (Chat Completions endpoint).

    The API key defaults to ``OPENAI_API_KEY``. The endpoint defaults to
    ``OPENAI_BASE_URL``, or ``https://api.openai.com/v1`` if that is unset.
    """

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Resolve the credentials and endpoint from arguments or the environment.

        Raises:
            ValueError: If no API key is given and ``OPENAI_API_KEY`` is unset.
        """
        super().__init__(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url or os.getenv(
                "OPENAI_BASE_URL") or "https://api.openai.com/v1"
        )
        if not self.api_key:
            raise ValueError(
                "OpenAI API key not found. Please set OPENAI_API_KEY environment variable "
                "or pass api_key parameter."
            )

    async def stream(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream OpenAI responses.

        Yields one ``"text_delta"`` event per content chunk. Tool-call fragments
        are accumulated per tool-call index; once the stream ends, each
        assembled call is yielded as a ``"tool_call"`` event.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            request = self._build_request(
                model, messages, system_prompt, tools, max_tokens)
            request["stream"] = True
            response_stream = await client.chat.completions.create(**request)
            # Maps index -> {"id", "name", "arguments"}. Fragments are keyed by
            # index because the id is not present in every delta.
            pending: Dict[int, Dict[str, str]] = {}
            async for chunk in response_stream:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    yield StreamEvent(type="text_delta", delta=delta.content)
                for tc in delta.tool_calls or []:
                    index = tc.index if tc.index is not None else 0
                    entry = pending.setdefault(
                        index,
                        {"id": tc.id or f"call_{index}", "name": "", "arguments": ""},
                    )
                    if tc.id:
                        entry["id"] = tc.id
                    if tc.function:
                        if tc.function.name:
                            entry["name"] = tc.function.name
                        if tc.function.arguments:
                            entry["arguments"] += tc.function.arguments
            for entry in pending.values():
                yield StreamEvent(
                    type="tool_call",
                    tool_call=ToolCall(
                        id=entry["id"],
                        name=entry["name"],
                        arguments=self._parse_arguments(entry["arguments"]),
                    ),
                )
        finally:
            await client.close()

    async def complete(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AssistantMessage:
        """Complete an OpenAI request and return the first choice.

        Tool-call arguments are decoded with the same lenient policy as
        :meth:`stream`: empty or invalid JSON becomes ``{}`` instead of raising.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            response = await client.chat.completions.create(
                **self._build_request(model, messages, system_prompt, tools, max_tokens)
            )
        finally:
            await client.close()

        choice = response.choices[0]
        content_parts = (
            [TextContent(text=choice.message.content)]
            if choice.message.content else []
        )
        tool_calls = [
            ToolCall(
                id=tc.id,
                name=tc.function.name,
                arguments=self._parse_arguments(tc.function.arguments),
            )
            for tc in (choice.message.tool_calls or [])
        ]
        usage = None
        if response.usage:
            usage = Usage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            )
        return AssistantMessage(
            content=content_parts,
            tool_calls=tool_calls,
            model=model,
            usage=usage,
            stop_reason=choice.finish_reason,
        )

    def _create_client(self):
        """Import the ``openai`` SDK and build an ``AsyncOpenAI`` client.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If the client constructor fails.
        """
        try:
            import openai
        except ImportError:
            raise ImportError(
                "openai package required. Install with: pip install openai") from None
        try:
            return openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
        except Exception as e:
            raise ValueError(f"Failed to create OpenAI client: {e}") from e

    def _build_request(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str],
        tools: Optional[List[Dict[str, Any]]],
        max_tokens: int,
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments for ``chat.completions.create``."""
        api_messages: List[Dict[str, Any]] = []
        if system_prompt:
            api_messages.append({"role": "system", "content": system_prompt})
        api_messages.extend(self._messages_to_dict(messages))
        request: Dict[str, Any] = {
            "model": model,
            "messages": api_messages,
            "max_tokens": max_tokens,
        }
        if tools:
            request["tools"] = [self._convert_tool(t) for t in tools]
        return request

    @staticmethod
    def _parse_arguments(raw: Optional[str]) -> Any:
        """Decode a JSON tool-call argument string.

        Returns ``{}`` for a missing or empty string, or for invalid JSON, so
        that one malformed call does not abort the whole response.
        """
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {}

    def _convert_tool(self, tool: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a tool definition to an OpenAI ``function`` tool.

        A missing or empty ``parameters`` entry becomes an empty object schema.
        A bare ``{}`` is not a JSON Schema of type ``object`` and may be
        rejected.
        """
        return {
            "type": "function",
            "function": {
                "name": tool["name"],
                "description": tool["description"],
                "parameters": tool.get("parameters") or {"type": "object", "properties": {}},
            },
        }
class AnthropicProvider(LLMProvider):
    """Anthropic Claude API provider (Messages API)."""

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Resolve the API key from the argument or ``ANTHROPIC_API_KEY``.

        ``base_url`` is forwarded to the client only when it is given.

        Raises:
            ValueError: If no API key is available.
        """
        super().__init__(
            api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
            base_url=base_url
        )
        if not self.api_key:
            raise ValueError(
                "Anthropic API key not found. Please set ANTHROPIC_API_KEY environment variable "
                "or pass api_key parameter."
            )

    async def stream(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AsyncIterator[StreamEvent]:
        """Stream Anthropic responses.

        Yields a ``"text_delta"`` event for each text delta. After the stream
        finishes, every ``tool_use`` block of the final message is yielded as a
        ``"tool_call"`` event, so tool calls are reported once fully assembled.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            request = self._build_request(
                model, messages, system_prompt, tools, max_tokens)
            # messages.stream() always streams and accepts no ``stream`` kwarg.
            async with client.messages.stream(**request) as response_stream:
                async for event in response_stream:
                    # Only text deltas are forwarded live; tool-input JSON
                    # fragments are reported via the final message below.
                    if event.type == "content_block_delta" and hasattr(event.delta, "text"):
                        yield StreamEvent(type="text_delta", delta=event.delta.text)
                final_message = await response_stream.get_final_message()
            for block in final_message.content:
                if block.type == "tool_use":
                    yield StreamEvent(
                        type="tool_call",
                        tool_call=ToolCall(
                            id=block.id,
                            name=block.name,
                            arguments=block.input,
                        ),
                    )
        finally:
            await client.close()

    async def complete(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 4096,
    ) -> AssistantMessage:
        """Complete an Anthropic request.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If the client cannot be constructed.
        """
        client = self._create_client()
        try:
            response = await client.messages.create(
                **self._build_request(model, messages, system_prompt, tools, max_tokens)
            )
        finally:
            await client.close()

        content_parts = []
        tool_calls = []
        for block in response.content:
            if block.type == "text":
                content_parts.append(TextContent(text=block.text))
            elif block.type == "tool_use":
                tool_calls.append(ToolCall(
                    id=block.id,
                    name=block.name,
                    arguments=block.input
                ))
        usage = None
        if response.usage:
            usage = Usage(
                input_tokens=response.usage.input_tokens,
                output_tokens=response.usage.output_tokens,
            )
        return AssistantMessage(
            content=content_parts,
            tool_calls=tool_calls,
            model=model,
            usage=usage,
            stop_reason=response.stop_reason,
        )

    def _create_client(self):
        """Import the ``anthropic`` SDK and build an ``AsyncAnthropic`` client.

        Raises:
            ImportError: If the ``anthropic`` package is not installed.
            ValueError: If the client constructor fails.
        """
        try:
            import anthropic
        except ImportError:
            raise ImportError(
                "anthropic package required. Install with: pip install anthropic") from None
        client_kwargs: Dict[str, Any] = {"api_key": self.api_key}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        try:
            return anthropic.AsyncAnthropic(**client_kwargs)
        except Exception as e:
            raise ValueError(f"Failed to create Anthropic client: {e}") from e

    def _build_request(
        self,
        model: str,
        messages: List[Message],
        system_prompt: Optional[str],
        tools: Optional[List[Dict[str, Any]]],
        max_tokens: int,
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments shared by ``messages.create``/``stream``."""
        request: Dict[str, Any] = {
            "model": model,
            "messages": self._anthropic_messages(messages),
            "max_tokens": max_tokens,
        }
        if system_prompt:
            request["system"] = system_prompt
        if tools:
            request["tools"] = [self._convert_tool(t) for t in tools]
        return request

    def _anthropic_messages(self, messages: List[Message]) -> List[Dict[str, Any]]:
        """Convert messages to Anthropic format.

        * ``UserMessage``: a ``"user"`` turn with string content.
        * ``AssistantMessage``: an ``"assistant"`` turn of non-empty ``text``
          blocks followed by ``tool_use`` blocks.
        * ``ToolResultMessage``: a ``tool_result`` block in a ``"user"`` turn.
          Consecutive results share one user turn, so the results of parallel
          tool calls are returned together.
        """
        result: List[Dict[str, Any]] = []
        for msg in messages:
            if isinstance(msg, UserMessage):
                result.append({
                    "role": "user",
                    "content": msg.content if isinstance(msg.content, str) else self._format_content(msg.content)
                })
            elif isinstance(msg, AssistantMessage):
                if isinstance(msg.content, str):
                    texts = [msg.content]
                else:
                    texts = [
                        item.text for item in (msg.content or [])
                        if isinstance(item, TextContent)
                    ]
                # Empty text blocks are skipped because the API rejects them.
                content: List[Dict[str, Any]] = [
                    {"type": "text", "text": text} for text in texts if text
                ]
                for tc in msg.tool_calls or []:
                    content.append({
                        "type": "tool_use",
                        "id": tc.id,
                        "name": tc.name,
                        "input": tc.arguments
                    })
                result.append({"role": "assistant", "content": content})
            elif isinstance(msg, ToolResultMessage):
                block = {
                    "type": "tool_result",
                    "tool_use_id": msg.tool_call_id,
                    "content": msg.content if isinstance(msg.content, str) else self._format_content(msg.content),
                    "is_error": msg.is_error
                }
                # User turns built from UserMessage carry str content, so a
                # list-content user turn here can only hold earlier tool results.
                if result and result[-1]["role"] == "user" and isinstance(result[-1]["content"], list):
                    result[-1]["content"].append(block)
                else:
                    result.append({"role": "user", "content": [block]})
        return result

    def _convert_tool(self, tool: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a tool definition to Anthropic format.

        ``input_schema`` must be an object JSON Schema, so a missing or empty
        ``parameters`` entry becomes ``{"type": "object", "properties": {}}``.
        """
        return {
            "name": tool["name"],
            "description": tool["description"],
            "input_schema": tool.get("parameters") or {"type": "object", "properties": {}},
        }
# Registry mapping the names accepted by get_provider() to provider classes.
PROVIDERS = dict(
    openai=OpenAIProvider,
    anthropic=AnthropicProvider,
)
def get_provider(provider_name: str, **kwargs) -> "LLMProvider":
    """Instantiate a registered LLM provider by name.

    The lookup first tries ``provider_name`` exactly as given. If that fails,
    it falls back to a case-insensitive, whitespace-trimmed match against the
    keys of ``PROVIDERS``, so ``"OpenAI"`` resolves to ``"openai"``.

    Args:
        provider_name: Registry key, e.g. ``"openai"`` or ``"anthropic"``.
        **kwargs: Forwarded unchanged to the provider's constructor
            (e.g. ``api_key``, ``base_url``).

    Returns:
        A new instance of the registered provider class.

    Raises:
        ValueError: If no provider is registered under ``provider_name``. The
            message lists the available names.
    """
    provider_class = PROVIDERS.get(provider_name)
    if provider_class is None and isinstance(provider_name, str):
        wanted = provider_name.strip().lower()
        provider_class = next(
            (cls for name, cls in PROVIDERS.items() if str(name).lower() == wanted),
            None,
        )
    if provider_class is None:
        available = ", ".join(sorted(map(str, PROVIDERS)))
        raise ValueError(f"Unknown provider: {provider_name} (available: {available})")
    return provider_class(**kwargs)