# Source code for ragit.providers.ollama

#
# Copyright RODMENA LIMITED 2025
# SPDX-License-Identifier: Apache-2.0
#
"""
Ollama provider for LLM and Embedding operations.

This provider connects to a local or remote Ollama server.
Configuration is loaded from environment variables.

Performance optimizations:
- Connection pooling via requests.Session()
- Async batch embedding via httpx (one request per batch)
- LRU cache for repeated embedding queries

Resilience features (via resilient-circuit):
- Retry with exponential backoff
- Circuit breaker pattern for fault tolerance
"""

from datetime import timedelta
from fractions import Fraction
from functools import lru_cache
from typing import Any

import httpx
import requests
from resilient_circuit import (
    CircuitProtectorPolicy,
    ExponentialDelay,
    RetryWithBackoffPolicy,
    SafetyNet,
)
from resilient_circuit.exceptions import ProtectedCallError, RetryLimitReached

from ragit.config import config
from ragit.exceptions import IndexingError, ProviderError
from ragit.logging import log_operation, logger
from ragit.providers.base import (
    BaseEmbeddingProvider,
    BaseLLMProvider,
    EmbeddingResponse,
    LLMResponse,
)


def _is_client_error(exc: BaseException) -> bool:
    """Return True if ``exc`` carries an HTTP 4xx response that retrying cannot fix.

    408 (Request Timeout) and 429 (Too Many Requests) are treated as transient.
    Exceptions without a ``response`` attribute (e.g. connection errors) are
    never client errors.
    """
    status = getattr(getattr(exc, "response", None), "status_code", None)
    return isinstance(status, int) and 400 <= status < 500 and status not in (408, 429)


def _is_transient(exc: BaseException, error_types: tuple[type[BaseException], ...]) -> bool:
    """Return True if ``exc`` or any exception in its ``__cause__`` chain matches ``error_types``.

    The provider's internal ``_do_*`` methods re-raise ``requests`` failures as
    ``ProviderError`` / ``IndexingError`` via ``raise ... from e``. A type check
    on the wrapper alone would therefore miss the underlying transport error
    (assuming the ragit exception classes do not subclass the handled types).
    Walking the explicit cause chain lets the resilience policies see the
    original error. Matching errors that are HTTP 4xx client errors (see
    ``_is_client_error``) are reported as non-transient.
    """
    seen: set[int] = set()  # guards against a (pathological) cyclic cause chain
    current: BaseException | None = exc
    while current is not None and id(current) not in seen:
        if isinstance(current, error_types):
            return not _is_client_error(current)
        seen.add(id(current))
        current = current.__cause__
    return False


def _create_generate_policy() -> SafetyNet:
    """Create resilience policy for LLM generation (longer timeouts, more tolerant).

    Retries up to 3 times with exponential backoff (1s..30s, factor 2, 10% jitter)
    and uses a circuit breaker keyed ``ollama_generate`` with a 60s cooldown.
    """
    retryable = (ConnectionError, TimeoutError, requests.RequestException)
    trips_circuit = (ConnectionError, requests.RequestException)
    return SafetyNet(
        policies=(
            RetryWithBackoffPolicy(
                max_retries=3,
                backoff=ExponentialDelay(
                    min_delay=timedelta(seconds=1),
                    max_delay=timedelta(seconds=30),
                    factor=2,
                    jitter=0.1,
                ),
                should_handle=lambda e: _is_transient(e, retryable),
            ),
            CircuitProtectorPolicy(
                resource_key="ollama_generate",
                cooldown=timedelta(seconds=60),
                failure_limit=Fraction(3, 10),  # 30% failure rate trips circuit
                success_limit=Fraction(4, 5),  # 80% success to close
                should_handle=lambda e: _is_transient(e, trips_circuit),
            ),
        )
    )


def _create_embed_policy() -> SafetyNet:
    """Create resilience policy for embeddings (faster, stricter).

    Retries up to 2 times with exponential backoff (0.5s..5s, factor 2, 10% jitter)
    and uses a circuit breaker keyed ``ollama_embed`` with a 30s cooldown.
    """
    retryable = (ConnectionError, TimeoutError, requests.RequestException)
    trips_circuit = (ConnectionError, requests.RequestException)
    return SafetyNet(
        policies=(
            RetryWithBackoffPolicy(
                max_retries=2,
                backoff=ExponentialDelay(
                    min_delay=timedelta(milliseconds=500),
                    max_delay=timedelta(seconds=5),
                    factor=2,
                    jitter=0.1,
                ),
                should_handle=lambda e: _is_transient(e, retryable),
            ),
            CircuitProtectorPolicy(
                resource_key="ollama_embed",
                cooldown=timedelta(seconds=30),
                failure_limit=Fraction(2, 5),  # 40% failure rate trips circuit
                # NOTE(review): Fraction(3, 3) normalizes to Fraction(1, 1); if
                # resilient_circuit derives a sample size from the denominator,
                # this means 1 of 1 rather than 3 of 3 — confirm intended value.
                success_limit=Fraction(3, 3),  # All 3 tests must succeed to close
                should_handle=lambda e: _is_transient(e, trips_circuit),
            ),
        )
    )


def _truncate_text(text: str, max_chars: int = 2000) -> str:
    """Truncate text to max_chars. Used BEFORE cache lookup to fix cache key bug."""
    return text[:max_chars] if len(text) > max_chars else text


# Module-level cache for embeddings (shared across instances).
# NOTE: Text must be truncated BEFORE calling this function so that the cache key
# is exactly the text sent to the server.
@lru_cache(maxsize=2048)
def _cached_embedding(
    text: str,
    model: str,
    embedding_url: str,
    timeout: int,
    api_key: str | None = None,
) -> tuple[float, ...]:
    """Fetch one embedding from Ollama's ``/api/embed``, memoized with an LRU cache.

    IMPORTANT: Caller must truncate ``text`` BEFORE calling this function.
    This ensures cache keys are consistent for truncated inputs.

    Parameters
    ----------
    text : str
        Already-truncated input text.
    model : str
        Embedding model name.
    embedding_url : str
        Base URL of the Ollama embedding server (without trailing slash).
    timeout : int
        Request timeout in seconds.
    api_key : str, optional
        When set, sent as a ``Bearer`` token in the ``Authorization`` header.
        Like every argument, it is part of the cache key.

    Returns
    -------
    tuple[float, ...]
        The embedding vector. A tuple is returned so the cached value shared
        between callers cannot be mutated.

    Raises
    ------
    requests.RequestException
        On transport or HTTP errors (``lru_cache`` does not cache exceptions).
    ValueError
        If Ollama returns no embedding.
    """
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    response = requests.post(
        f"{embedding_url}/api/embed",
        headers=headers,
        json={"model": model, "input": text},
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    embeddings = data.get("embeddings", [])
    if not embeddings or not embeddings[0]:
        raise ValueError("Empty embedding returned from Ollama")
    return tuple(embeddings[0])


class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
    """
    Ollama provider for both LLM and Embedding operations.

    Performance features:
    - Connection pooling via requests.Session() for faster sequential requests
    - Native batch embedding via /api/embed endpoint (single API call)
    - LRU cache for repeated embedding queries (2048 entries)

    Parameters
    ----------
    base_url : str, optional
        Ollama server URL for generate/chat/health calls
        (default: ``config.OLLAMA_BASE_URL``, loaded from the environment).
    embedding_url : str, optional
        Ollama server URL for embedding calls
        (default: ``config.OLLAMA_EMBEDDING_URL``).
    api_key : str, optional
        API key sent as a Bearer token on every request
        (default: ``config.OLLAMA_API_KEY``).
    timeout : int, optional
        Legacy single timeout in seconds. When given, it overrides every
        per-operation timeout (including entries passed via ``timeouts``).
        When omitted, the per-operation defaults apply; ``config.OLLAMA_TIMEOUT``
        only populates the legacy ``timeout`` attribute.
    timeouts : dict[str, int], optional
        Per-operation timeout overrides merged over ``DEFAULT_TIMEOUTS``
        (keys: generate, chat, embed, embed_batch, health, list_models).
    use_cache : bool, optional
        Use the module-level LRU cache for single ``embed`` calls (default: True).
    use_resilience : bool, optional
        Wrap generate/chat/embed/embed_batch in retry + circuit breaker
        policies (default: True).

    Examples
    --------
    >>> provider = OllamaProvider()
    >>> response = provider.generate("What is RAG?", model="llama3")
    >>> print(response.text)

    >>> # Batch embedding (single API call)
    >>> embeddings = provider.embed_batch(texts, "mxbai-embed-large")
    """

    # Known embedding model dimensions (fallback until a response reports the
    # actual vector length; unknown models default to 768).
    EMBEDDING_DIMENSIONS: dict[str, int] = {
        "nomic-embed-text": 768,
        "nomic-embed-text:latest": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "qwen3-embedding": 4096,
        "qwen3-embedding:0.6b": 1024,
        "qwen3-embedding:4b": 2560,
        "qwen3-embedding:8b": 4096,
    }

    # Max characters per embedding request (safe limit for 512 token models)
    MAX_EMBED_CHARS = 2000

    # Default timeouts per operation type (in seconds)
    DEFAULT_TIMEOUTS: dict[str, int] = {
        "generate": 300,  # 5 minutes for LLM generation
        "chat": 300,  # 5 minutes for chat
        "embed": 30,  # 30 seconds for single embedding
        "embed_batch": 120,  # 2 minutes for batch embedding
        "health": 5,  # 5 seconds for health check
        "list_models": 10,  # 10 seconds for listing models
    }

    def __init__(
        self,
        base_url: str | None = None,
        embedding_url: str | None = None,
        api_key: str | None = None,
        timeout: int | None = None,
        timeouts: dict[str, int] | None = None,
        use_cache: bool = True,
        use_resilience: bool = True,
    ) -> None:
        """Configure endpoints, credentials, timeouts, caching and resilience.

        The HTTP session is created lazily on first use (see ``session``).
        """
        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
        self.api_key = api_key or config.OLLAMA_API_KEY
        self.use_cache = use_cache
        self.use_resilience = use_resilience
        self._current_embed_model: str | None = None
        self._current_dimensions: int = 768  # default until an embedding call reports otherwise

        # Per-operation timeouts (merge user overrides with defaults)
        self._timeouts = {**self.DEFAULT_TIMEOUTS, **(timeouts or {})}
        # Legacy single timeout parameter overrides all operations
        if timeout is not None:
            self._timeouts = {k: timeout for k in self._timeouts}
        # Legacy attribute kept for backwards compatibility; requests use self._timeouts.
        self.timeout = timeout or config.OLLAMA_TIMEOUT

        # Connection pooling via session (created lazily)
        self._session: requests.Session | None = None

        # Resilience policies (retry + circuit breaker)
        self._generate_policy: SafetyNet | None = None
        self._embed_policy: SafetyNet | None = None
        if use_resilience:
            self._generate_policy = _create_generate_policy()
            self._embed_policy = _create_embed_policy()

    @property
    def session(self) -> requests.Session:
        """Lazy-initialized session for connection pooling.

        Note: API key is NOT stored in session headers to prevent potential
        exposure in logs or error messages. Authentication is handled
        per-request via _get_headers().
        """
        if self._session is None:
            self._session = requests.Session()
            self._session.headers.update({"Content-Type": "application/json"})
            # Security: API key is injected per-request via _get_headers()
            # rather than stored in session headers to prevent log exposure
        return self._session

    def close(self) -> None:
        """Close the session and release resources (safe to call repeatedly)."""
        # getattr: __del__ may run on a partially-initialized instance.
        session = getattr(self, "_session", None)
        if session is not None:
            session.close()
            self._session = None

    def __enter__(self) -> "OllamaProvider":
        """Context manager entry - returns self for use in 'with' statements.

        Example:
            with OllamaProvider() as provider:
                result = provider.generate("Hello", model="llama3")
            # Session automatically closed here
        """
        return self

    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: object) -> None:
        """Context manager exit - ensures cleanup regardless of exceptions."""
        self.close()

    def __del__(self) -> None:
        """Cleanup on garbage collection (fallback, prefer context manager)."""
        self.close()

    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
        """Build per-request headers, adding a Bearer token when an API key is set."""
        headers = {"Content-Type": "application/json"}
        if include_auth and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @property
    def provider_name(self) -> str:
        """Provider identifier reported in responses."""
        return "ollama"

    @property
    def dimensions(self) -> int:
        """Vector length from the latest embedding call (or the model-table default)."""
        return self._current_dimensions

    def is_available(self) -> bool:
        """Check if Ollama server is reachable (GET /api/tags returns 200)."""
        try:
            response = self.session.get(
                f"{self.base_url}/api/tags",
                headers=self._get_headers(),
                timeout=self._timeouts["health"],
            )
            return bool(response.status_code == 200)
        except requests.RequestException:
            return False

    def list_models(self) -> list[dict[str, Any]]:
        """List available models on the Ollama server.

        Raises
        ------
        ProviderError
            If the request fails or returns an HTTP error status.
        """
        try:
            response = self.session.get(
                f"{self.base_url}/api/tags",
                headers=self._get_headers(),
                timeout=self._timeouts["list_models"],
            )
            response.raise_for_status()
            data = response.json()
            return list(data.get("models", []))
        except requests.RequestException as e:
            raise ProviderError("Failed to list Ollama models", e) from e

    def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Generate text using Ollama with optional resilience (retry + circuit breaker).

        Raises
        ------
        ProviderError
            On request failure, when the circuit breaker is open, or after
            retries are exhausted.
        """
        if self.use_resilience and self._generate_policy is not None:

            @self._generate_policy
            def _protected_generate() -> LLMResponse:
                return self._do_generate(prompt, model, system_prompt, temperature, max_tokens)

            try:
                return _protected_generate()
            except ProtectedCallError as e:
                logger.warning(f"Circuit breaker OPEN for ollama.generate (model={model})")
                raise ProviderError("Ollama service unavailable - circuit breaker open", e) from e
            except RetryLimitReached as e:
                logger.error(f"Retry limit reached for ollama.generate (model={model}): {e.__cause__}")
                raise ProviderError("Ollama generate failed after retries", e.__cause__) from e
        else:
            return self._do_generate(prompt, model, system_prompt, temperature, max_tokens)

    def _do_generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Internal generate implementation (unprotected) calling POST /api/generate."""
        options: dict[str, float | int] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload: dict[str, str | bool | dict[str, float | int]] = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": options,
        }
        if system_prompt:
            payload["system"] = system_prompt

        with log_operation("ollama.generate", model=model, prompt_len=len(prompt)) as ctx:
            try:
                response = self.session.post(
                    f"{self.base_url}/api/generate",
                    headers=self._get_headers(),
                    json=payload,
                    timeout=self._timeouts["generate"],
                )
                response.raise_for_status()
                data = response.json()
                ctx["completion_tokens"] = data.get("eval_count")
                return LLMResponse(
                    text=data.get("response", ""),
                    model=model,
                    provider=self.provider_name,
                    usage={
                        "prompt_tokens": data.get("prompt_eval_count"),
                        "completion_tokens": data.get("eval_count"),
                        "total_duration": data.get("total_duration"),
                    },
                )
            except requests.RequestException as e:
                raise ProviderError("Ollama generate failed", e) from e

    def embed(self, text: str, model: str) -> EmbeddingResponse:
        """Generate embedding using Ollama with optional caching and resilience.

        Raises
        ------
        IndexingError
            On request failure or after retries are exhausted.
        ProviderError
            When the embedding circuit breaker is open.
        """
        if self.use_resilience and self._embed_policy is not None:

            @self._embed_policy
            def _protected_embed() -> EmbeddingResponse:
                return self._do_embed(text, model)

            try:
                return _protected_embed()
            except ProtectedCallError as e:
                logger.warning(f"Circuit breaker OPEN for ollama.embed (model={model})")
                raise ProviderError("Ollama embedding service unavailable - circuit breaker open", e) from e
            except RetryLimitReached as e:
                logger.error(f"Retry limit reached for ollama.embed (model={model}): {e.__cause__}")
                raise IndexingError("Ollama embed failed after retries", e.__cause__) from e
        else:
            return self._do_embed(text, model)

    def _do_embed(self, text: str, model: str) -> EmbeddingResponse:
        """Internal embed implementation (unprotected).

        Raises ``ValueError`` (unwrapped) if Ollama returns no embedding.
        """
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        # Truncate BEFORE cache lookup so the cache key matches the text actually sent
        truncated_text = _truncate_text(text, self.MAX_EMBED_CHARS)
        was_truncated = len(text) > self.MAX_EMBED_CHARS

        with log_operation("ollama.embed", model=model, text_len=len(text), truncated=was_truncated) as ctx:
            try:
                if self.use_cache:
                    # Cached path; the API key is forwarded so authenticated
                    # servers work the same as on the uncached path.
                    embedding = _cached_embedding(
                        truncated_text,
                        model,
                        self.embedding_url,
                        self._timeouts["embed"],
                        self.api_key,
                    )
                    ctx["cache"] = "hit_or_miss"  # lru_cache does not report per-call hits
                else:
                    # Direct call without cache
                    response = self.session.post(
                        f"{self.embedding_url}/api/embed",
                        headers=self._get_headers(),
                        json={"model": model, "input": truncated_text},
                        timeout=self._timeouts["embed"],
                    )
                    response.raise_for_status()
                    data = response.json()
                    embeddings = data.get("embeddings", [])
                    if not embeddings or not embeddings[0]:
                        raise ValueError("Empty embedding returned from Ollama")
                    embedding = tuple(embeddings[0])
                    ctx["cache"] = "disabled"

                # Update dimensions from actual response
                self._current_dimensions = len(embedding)
                ctx["dimensions"] = len(embedding)

                return EmbeddingResponse(
                    embedding=embedding,
                    model=model,
                    provider=self.provider_name,
                    dimensions=len(embedding),
                )
            except requests.RequestException as e:
                raise IndexingError("Ollama embed failed", e) from e

    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
        """Generate embeddings for multiple texts in a single API call with resilience.

        The /api/embed endpoint supports batch inputs natively. An empty
        ``texts`` list returns ``[]`` without contacting the server.
        """
        if not texts:
            # Nothing to embed: skip the HTTP round-trip and the circuit breaker.
            return []
        if self.use_resilience and self._embed_policy is not None:

            @self._embed_policy
            def _protected_embed_batch() -> list[EmbeddingResponse]:
                return self._do_embed_batch(texts, model)

            try:
                return _protected_embed_batch()
            except ProtectedCallError as e:
                logger.warning(f"Circuit breaker OPEN for ollama.embed_batch (model={model}, batch_size={len(texts)})")
                raise ProviderError("Ollama embedding service unavailable - circuit breaker open", e) from e
            except RetryLimitReached as e:
                logger.error(f"Retry limit reached for ollama.embed_batch (model={model}): {e.__cause__}")
                raise IndexingError("Ollama batch embed failed after retries", e.__cause__) from e
        else:
            return self._do_embed_batch(texts, model)

    def _do_embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
        """Internal batch embed implementation (unprotected)."""
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        # Truncate oversized inputs
        truncated_texts = [_truncate_text(text, self.MAX_EMBED_CHARS) for text in texts]
        truncated_count = sum(1 for t, tt in zip(texts, truncated_texts, strict=True) if len(t) != len(tt))

        with log_operation(
            "ollama.embed_batch", model=model, batch_size=len(texts), truncated_count=truncated_count
        ) as ctx:
            try:
                response = self.session.post(
                    f"{self.embedding_url}/api/embed",
                    headers=self._get_headers(),
                    json={"model": model, "input": truncated_texts},
                    timeout=self._timeouts["embed_batch"],
                )
                response.raise_for_status()
                data = response.json()
            except requests.RequestException as e:
                raise IndexingError("Ollama batch embed failed", e) from e

            results = self._build_batch_responses(data, model, len(texts))
            ctx["dimensions"] = self._current_dimensions
            return results

    def _build_batch_responses(self, data: dict[str, Any], model: str, expected: int) -> list[EmbeddingResponse]:
        """Convert an /api/embed batch payload into EmbeddingResponse objects.

        Updates ``_current_dimensions`` from the last non-empty vector.

        Raises
        ------
        ValueError
            If no embeddings were returned, or their count differs from
            ``expected``. A mismatch would otherwise silently misalign results
            with the input texts.
        """
        embeddings_list = data.get("embeddings", [])
        if not embeddings_list:
            raise ValueError("Empty embeddings returned from Ollama")
        if len(embeddings_list) != expected:
            raise ValueError(f"Ollama returned {len(embeddings_list)} embeddings for {expected} inputs")

        results = []
        for embedding_data in embeddings_list:
            embedding = tuple(embedding_data) if embedding_data else ()
            if embedding:
                self._current_dimensions = len(embedding)
            results.append(
                EmbeddingResponse(
                    embedding=embedding,
                    model=model,
                    provider=self.provider_name,
                    dimensions=len(embedding),
                )
            )
        return results

    async def embed_batch_async(
        self,
        texts: list[str],
        model: str,
        max_concurrent: int = 10,  # kept for API compatibility, no longer used
    ) -> list[EmbeddingResponse]:
        """Generate embeddings for multiple texts asynchronously.

        The /api/embed endpoint supports batch inputs natively, so this makes
        a single async HTTP request for all texts. No retry/circuit breaker is
        applied on this path.

        Parameters
        ----------
        texts : list[str]
            Texts to embed. An empty list returns ``[]`` without a request.
        model : str
            Embedding model name.
        max_concurrent : int
            Deprecated, kept for API compatibility. No longer used since the
            API now supports native batching.

        Returns
        -------
        list[EmbeddingResponse]
            Embeddings in the same order as input texts.

        Raises
        ------
        IndexingError
            On httpx transport or HTTP status errors.
        ValueError
            If no embeddings, or the wrong number of embeddings, are returned.

        Examples
        --------
        >>> import asyncio
        >>> embeddings = asyncio.run(provider.embed_batch_async(texts, "mxbai-embed-large"))
        """
        if not texts:
            return []

        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        # Truncate oversized inputs
        truncated_texts = [_truncate_text(text, self.MAX_EMBED_CHARS) for text in texts]

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    f"{self.embedding_url}/api/embed",
                    # Same auth headers as the sync paths (previously omitted here).
                    headers=self._get_headers(),
                    json={"model": model, "input": truncated_texts},
                    timeout=self._timeouts["embed_batch"],
                )
                response.raise_for_status()
                data = response.json()
        except httpx.HTTPError as e:
            raise IndexingError("Ollama async batch embed failed", e) from e

        return self._build_batch_responses(data, model, len(texts))

    def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """
        Chat completion using Ollama with optional resilience.

        Parameters
        ----------
        messages : list[dict]
            List of messages with 'role' and 'content' keys.
        model : str
            Model identifier.
        temperature : float
            Sampling temperature.
        max_tokens : int, optional
            Maximum tokens to generate.

        Returns
        -------
        LLMResponse
            The generated response.

        Raises
        ------
        ProviderError
            On request failure, when the circuit breaker is open, or after
            retries are exhausted.
        """
        if self.use_resilience and self._generate_policy is not None:

            @self._generate_policy
            def _protected_chat() -> LLMResponse:
                return self._do_chat(messages, model, temperature, max_tokens)

            try:
                return _protected_chat()
            except ProtectedCallError as e:
                logger.warning(f"Circuit breaker OPEN for ollama.chat (model={model})")
                raise ProviderError("Ollama service unavailable - circuit breaker open", e) from e
            except RetryLimitReached as e:
                logger.error(f"Retry limit reached for ollama.chat (model={model}): {e.__cause__}")
                raise ProviderError("Ollama chat failed after retries", e.__cause__) from e
        else:
            return self._do_chat(messages, model, temperature, max_tokens)

    def _do_chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Internal chat implementation (unprotected) calling POST /api/chat."""
        options: dict[str, float | int] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
            "model": model,
            "messages": messages,
            "stream": False,
            "options": options,
        }

        with log_operation("ollama.chat", model=model, message_count=len(messages)) as ctx:
            try:
                response = self.session.post(
                    f"{self.base_url}/api/chat",
                    headers=self._get_headers(),
                    json=payload,
                    timeout=self._timeouts["chat"],
                )
                response.raise_for_status()
                data = response.json()
                ctx["completion_tokens"] = data.get("eval_count")

                return LLMResponse(
                    text=data.get("message", {}).get("content", ""),
                    model=model,
                    provider=self.provider_name,
                    usage={
                        "prompt_tokens": data.get("prompt_eval_count"),
                        "completion_tokens": data.get("eval_count"),
                    },
                )
            except requests.RequestException as e:
                raise ProviderError("Ollama chat failed", e) from e

    # Circuit breaker status monitoring
    @staticmethod
    def _circuit_status(policy: SafetyNet | None) -> str:
        """Read the circuit breaker state name from a SafetyNet's second policy.

        Returns "unknown" when the policy layout or status attribute is not
        as expected.
        """
        policies = getattr(policy, "policies", None)
        if policies is None or len(policies) < 2:
            return "unknown"
        # The circuit protector is the second policy (see _create_*_policy).
        status = getattr(policies[1], "status", None)
        return str(getattr(status, "name", "unknown"))

    @property
    def generate_circuit_status(self) -> str:
        """Get generate circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
        if not self.use_resilience or self._generate_policy is None:
            return "disabled"
        return self._circuit_status(self._generate_policy)

    @property
    def embed_circuit_status(self) -> str:
        """Get embed circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
        if not self.use_resilience or self._embed_policy is None:
            return "disabled"
        return self._circuit_status(self._embed_policy)

    @staticmethod
    def clear_embedding_cache() -> None:
        """Clear the module-level embedding cache (shared by all instances)."""
        _cached_embedding.cache_clear()

    @staticmethod
    def embedding_cache_info() -> dict[str, int]:
        """Get embedding cache statistics (hits, misses, maxsize, currsize)."""
        info = _cached_embedding.cache_info()
        return {
            "hits": info.hits,
            "misses": info.misses,
            "maxsize": info.maxsize or 0,
            "currsize": info.currsize,
        }
# Export the EMBEDDING_DIMENSIONS for external use (same dict object as the
# class attribute, so mutations are visible through both names).
EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS