Skip to content

Adapters

Adapters are the concrete implementations of the Ports. Each adapter binds the system to one specific external technology.

Container is the only entry point

Never import an adapter class directly in your application code. Always go through services/container.py — that is the single wiring point where adapters are named. This preserves swappability.


Provider overview

Two providers are supported for both embedding and LLM. Selection is controlled entirely by environment variables — no code changes required.

| Component         | EMBED_PROVIDER=vertex (default) | EMBED_PROVIDER=openai  |
| ----------------- | ------------------------------- | ---------------------- |
| Embedding adapter | VertexEmbeddingAdapter          | OpenAIEmbeddingAdapter |
| Auth required     | gcloud ADC token                | OPENAI_API_KEY         |
| Model default     | text-embedding-005              | text-embedding-3-small |

| Component     | LLM_PROVIDER=vertex (default) | LLM_PROVIDER=openai |
| ------------- | ----------------------------- | ------------------- |
| LLM adapter   | GeminiLLMAdapter              | OpenAILLMAdapter    |
| Auth required | gcloud ADC token              | OPENAI_API_KEY      |
| Model default | gemini-2.5-flash              | gpt-4o              |

Mix-and-match is supported (e.g. EMBED_PROVIDER=openai LLM_PROVIDER=vertex).


GCPAuthManager

Shared across both GCP adapters (VertexEmbeddingAdapter and GeminiLLMAdapter). Manages the bearer token lifecycle — fetching, caching, and refreshing — so that both adapters share a single token rather than making separate gcloud subprocess calls.

Only instantiated when EMBED_PROVIDER=vertex or LLM_PROVIDER=vertex.

gcp_auth

adapters/gcp_auth.py ────────────────────────────────────────────────────────────────────────────── GCP authentication manager.

Wraps the gcloud auth print-access-token subprocess call, caches the token in memory, and refreshes automatically when it is within TOKEN_REFRESH_MARGIN seconds of expiry.

Design notes
  • All adapters that need GCP auth receive a GCPAuthManager instance via DI.
  • A single GCPAuthManager is created in services/container.py and shared across VertexEmbeddingAdapter and GeminiLLMAdapter → one token, one call.
  • Thread-safe: uses a threading.Lock for token refresh in multi-threaded contexts (Streamlit, FastAPI worker threads).

GCPAuthManager

Manages a GCP Application Default Credentials access token.

Usage (injected by container.py — do not instantiate manually):

    auth = GCPAuthManager(settings)
    token = auth.get_token()  # fresh or cached

Source code in prod/adapters/gcp_auth.py
class GCPAuthManager:
    """Manages a GCP Application Default Credentials access token.

    Wraps ``gcloud auth print-access-token``, caching the result in memory
    and refreshing once the cached token nears expiry.

    Usage (injected by container.py — do not instantiate manually):
        auth = GCPAuthManager(settings)
        token = auth.get_token()          # fresh or cached
    """

    def __init__(self, settings: Settings) -> None:
        self._gcloud_path = settings.gcloud_path
        self._state = _TokenState()
        # Serialises check-and-refresh across threads (Streamlit / FastAPI).
        self._lock = threading.Lock()
        logger.debug("GCPAuthManager initialised | gcloud=%s", self._gcloud_path)

    # ── Public API ─────────────────────────────────────────────────────────

    def get_token(self) -> str:
        """Return a valid access token, refreshing if necessary.

        Returns:
            Bearer token string.

        Raises:
            AuthenticationError: If gcloud fails.
        """
        with self._lock:
            if self._needs_refresh():
                self._refresh()
            token = self._state.value
        return token

    def invalidate(self) -> None:
        """Force the next call to get_token() to fetch a fresh token."""
        with self._lock:
            # A zeroed expiry guarantees _needs_refresh() is true next time.
            self._state.expires_at = 0.0
            logger.debug("GCPAuthManager: token invalidated")

    # ── Private helpers ────────────────────────────────────────────────────

    def _needs_refresh(self) -> bool:
        """True when no token is cached or it expires within the margin."""
        if self._state.value == "":
            return True
        return self._state.expires_at <= time.time() + TOKEN_REFRESH_MARGIN

    def _refresh(self) -> None:
        """Fetch a new token via the gcloud CLI and cache it with a TTL."""
        logger.info("GCPAuthManager: refreshing access token …")
        try:
            proc = subprocess.run(
                [self._gcloud_path, "auth", "print-access-token"],
                capture_output=True,
                text=True,
                timeout=30,
                check=True,
            )
        except FileNotFoundError as exc:
            raise AuthenticationError(
                f"gcloud not found at '{self._gcloud_path}'. "
                "Set the GCLOUD_PATH environment variable."
            ) from exc
        except subprocess.TimeoutExpired as exc:
            raise AuthenticationError("gcloud timed out fetching access token") from exc
        except subprocess.CalledProcessError as exc:
            stderr = exc.stderr.strip() if exc.stderr else "(no stderr)"
            raise AuthenticationError(
                f"gcloud auth print-access-token failed: {stderr}"
            ) from exc

        fresh = proc.stdout.strip()
        if not fresh:
            raise AuthenticationError("gcloud returned an empty access token")

        self._state.value = fresh
        self._state.expires_at = time.time() + TOKEN_TTL_SECONDS
        logger.info("GCPAuthManager: token refreshed (expires in %ds)", TOKEN_TTL_SECONDS)

get_token

get_token() -> str

Return a valid access token, refreshing if necessary.

Returns:

Type Description
str

Bearer token string.

Raises:

Type Description
AuthenticationError

If gcloud fails.

Source code in prod/adapters/gcp_auth.py
def get_token(self) -> str:
    """Return a valid access token, refreshing if necessary.

    Returns:
        Bearer token string.

    Raises:
        AuthenticationError: If gcloud fails.
    """
    # The lock serialises check-and-refresh so concurrent callers (Streamlit /
    # FastAPI worker threads) cannot spawn duplicate gcloud subprocesses.
    with self._lock:
        if self._needs_refresh():
            self._refresh()
        return self._state.value

invalidate

invalidate() -> None

Force the next call to get_token() to fetch a fresh token.

Source code in prod/adapters/gcp_auth.py
def invalidate(self) -> None:
    """Force the next call to get_token() to fetch a fresh token."""
    with self._lock:
        # Zeroing the expiry makes the next refresh check trip; the stale
        # token string itself is left in place until then.
        self._state.expires_at = 0.0
        logger.debug("GCPAuthManager: token invalidated")

VertexEmbeddingAdapter

Implements EmbeddingPort using the Vertex AI Predict REST API.

Key behaviours:

  • embed_query uses RETRIEVAL_QUERY task type (asymmetric retrieval)
  • embed_document uses RETRIEVAL_DOCUMENT task type
  • Batch calls chunk to embed_batch_size (default 50) to stay within API limits
  • HTTP 401 → invalidates the cached token and retries once
  • HTTP 429 / 503 → exponential back-off with up to embed_retries attempts

vertex_embedding

adapters/vertex_embedding.py ────────────────────────────────────────────────────────────────────────────── Implements EmbeddingPort using Vertex AI text-embedding-005.

Key behaviour
  • embed_query → RETRIEVAL_QUERY task type (asymmetric retrieval)
  • embed_document → RETRIEVAL_DOCUMENT task type
  • embed_documents_batch → single API call for up to embed_batch_size items
  • Retries on transient HTTP errors (429, 503) with exponential back-off
  • Corporate proxy support via settings.https_proxy
  • Token 401 → triggers GCPAuthManager.invalidate() then retries once

To swap to a different embedding model (e.g. OpenAI text-embedding-3-large):

1. Write OpenAIEmbeddingAdapter implementing EmbeddingPort
2. Change ONE import in services/container.py
3. Update EMBED_DIM in settings / .env

VertexEmbeddingAdapter

Vertex AI text-embedding-005 adapter.

Injected into HybridRetriever via services/container.py.

Source code in prod/adapters/vertex_embedding.py
class VertexEmbeddingAdapter:
    """Vertex AI text-embedding-005 adapter.

    Implements EmbeddingPort via the Vertex AI Predict REST API.
    Injected into HybridRetriever via services/container.py.

    Behaviour:
        • embed_query uses the RETRIEVAL_QUERY task type (asymmetric retrieval).
        • embed_document / embed_documents_batch use RETRIEVAL_DOCUMENT.
        • HTTP 401 invalidates the shared GCPAuthManager token and retries.
        • HTTP 429/503 back off exponentially within the retry budget.
    """

    def __init__(self, auth: GCPAuthManager, settings: Settings) -> None:
        self._auth = auth
        self._settings = settings
        self._url = _build_embed_url(settings)
        # Corporate proxy support: requests expects a scheme-keyed mapping.
        self._proxies = (
            {"https": f"http://{settings.https_proxy}"}
            if settings.https_proxy
            else {}
        )
        logger.debug("VertexEmbeddingAdapter ready | url=%s", self._url)

    # ── EmbeddingPort implementation ───────────────────────────────────────

    @property
    def model_name(self) -> str:
        """Name of the underlying Vertex AI embedding model."""
        return self._settings.gcp_embed_model

    @property
    def dimensions(self) -> int:
        """Vector dimensionality (must match the pgvector column width)."""
        return self._settings.embed_dim

    def embed_query(self, text: str) -> list[float]:
        """Embed a query with RETRIEVAL_QUERY task type."""
        return self._embed_single(text, task_type=_TASK_QUERY)

    def embed_document(self, text: str, title: str = "") -> list[float]:
        """Embed a document with RETRIEVAL_DOCUMENT task type."""
        return self._embed_single(text, task_type=_TASK_DOCUMENT, title=title)

    def embed_documents_batch(
        self,
        texts: list[str],
        titles: list[str] | None = None,
    ) -> list[list[float] | None]:
        """Embed multiple documents in batches.

        Args:
            texts:  Document strings to embed.
            titles: Optional per-document titles; padded with "" when absent.

        Returns:
            One vector per input text, in order; None for items that failed.
        """
        if not texts:
            return []
        titles = titles or [""] * len(texts)
        batch_size = self._settings.embed_batch_size
        all_results: list[list[float] | None] = []

        # Chunk to embed_batch_size to stay within the API's instance limit.
        for start in range(0, len(texts), batch_size):
            chunk_texts = texts[start : start + batch_size]
            chunk_titles = titles[start : start + batch_size]
            batch_results = self._embed_batch(chunk_texts, _TASK_DOCUMENT, chunk_titles)
            all_results.extend(batch_results)

        return all_results

    # ── Private helpers ────────────────────────────────────────────────────

    def _embed_single(
        self,
        text: str,
        task_type: str,
        title: str = "",
        retries: int | None = None,
    ) -> list[float]:
        """Call the Vertex AI Predict endpoint for a single text.

        Args:
            text:      Content to embed.
            task_type: Vertex AI task type (query vs document retrieval).
            title:     Optional document title; sent only when non-empty.
            retries:   Attempt budget.  Defaults to settings.embed_retries so
                       the single and batch paths share one configuration knob.

        Raises:
            EmbeddingError: On API failure or unexpected response shape.
        """
        if retries is None:
            # Fix: was hard-coded to 3, silently ignoring the
            # settings.embed_retries value that the batch path honours.
            retries = self._settings.embed_retries
        instance: dict[str, Any] = {"content": text, "task_type": task_type}
        if title:
            instance["title"] = title
        payload = {"instances": [instance]}

        response_json = self._post_with_retry(payload, retries=retries)
        try:
            return response_json["predictions"][0]["embeddings"]["values"]
        except (KeyError, IndexError, TypeError) as exc:
            raise EmbeddingError(
                f"Unexpected embed response shape: {list(response_json.keys())}"
            ) from exc

    def _embed_batch(
        self,
        texts: list[str],
        task_type: str,
        titles: list[str],
    ) -> list[list[float] | None]:
        """Call Predict for a batch of texts; returns None for failed items."""
        instances = [
            {"content": t, "task_type": task_type, **({"title": tl} if tl else {})}
            for t, tl in zip(texts, titles)
        ]
        payload = {"instances": instances}
        try:
            response_json = self._post_with_retry(payload, retries=self._settings.embed_retries)
            predictions = response_json.get("predictions", [])
            results: list[list[float] | None] = []
            for pred in predictions:
                try:
                    results.append(pred["embeddings"]["values"])
                except (KeyError, TypeError):
                    results.append(None)
            # Pad with None if fewer predictions returned than requested
            while len(results) < len(texts):
                results.append(None)
            return results
        except EmbeddingError:
            # Best-effort contract: a failed batch degrades to None per item
            # rather than aborting the caller's whole run.
            logger.warning("Batch embed failed; returning None for all %d items", len(texts))
            return [None] * len(texts)

    def _post_with_retry(
        self,
        payload: dict,
        retries: int = 3,
    ) -> dict:
        """POST to Vertex AI with retry and token refresh on 401.

        Args:
            payload: JSON body for the Predict endpoint.
            retries: Maximum number of attempts.

        Returns:
            Decoded JSON response body.

        Raises:
            EmbeddingError: On a non-retryable HTTP error, or after the
                retry budget is exhausted.
        """
        delay = 1.0
        last_exc: Exception | None = None

        for attempt in range(1, retries + 1):
            # Fetch the token each attempt: a 401 on the previous attempt
            # invalidates it, so this picks up a fresh one.
            token = self._auth.get_token()
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }
            try:
                resp = requests.post(
                    self._url,
                    headers=headers,
                    json=payload,
                    proxies=self._proxies,
                    timeout=self._settings.embed_timeout,
                )
            except requests.RequestException as exc:
                last_exc = exc
                logger.warning("Embed HTTP error (attempt %d/%d): %s", attempt, retries, exc)
                if attempt < retries:  # fix: no pointless back-off after the final attempt
                    time.sleep(delay)
                    delay *= 2
                continue

            if resp.status_code == 401:
                logger.warning("Embed 401 — invalidating token and retrying")
                self._auth.invalidate()
                continue

            if resp.status_code in (429, 503):
                logger.warning("Embed %d (attempt %d/%d) — back-off %.1fs",
                               resp.status_code, attempt, retries, delay)
                if attempt < retries:  # fix: no pointless back-off after the final attempt
                    time.sleep(delay)
                    delay *= 2
                continue

            if not resp.ok:
                raise EmbeddingError(
                    f"Vertex AI Embed returned HTTP {resp.status_code}: {resp.text[:200]}"
                )

            return resp.json()

        raise EmbeddingError(
            f"Embed failed after {retries} attempts"
        ) from last_exc

embed_query

embed_query(text: str) -> list[float]

Embed a query with RETRIEVAL_QUERY task type.

Source code in prod/adapters/vertex_embedding.py
def embed_query(self, text: str) -> list[float]:
    """Embed a query with RETRIEVAL_QUERY task type.

    Args:
        text: Natural-language query string.

    Returns:
        Dense float vector sized to the configured embedding dimension.

    Raises:
        EmbeddingError: On API failure or unexpected response shape.
    """
    return self._embed_single(text, task_type=_TASK_QUERY)

embed_document

embed_document(text: str, title: str = '') -> list[float]

Embed a document with RETRIEVAL_DOCUMENT task type.

Source code in prod/adapters/vertex_embedding.py
def embed_document(self, text: str, title: str = "") -> list[float]:
    """Embed a document with RETRIEVAL_DOCUMENT task type.

    Args:
        text:  Document body text.
        title: Optional document title; forwarded to the API when non-empty.

    Returns:
        Dense float vector sized to the configured embedding dimension.

    Raises:
        EmbeddingError: On API failure or unexpected response shape.
    """
    return self._embed_single(text, task_type=_TASK_DOCUMENT, title=title)

embed_documents_batch

embed_documents_batch(texts: list[str], titles: list[str] | None = None) -> list[list[float] | None]

Embed multiple documents in batches.

Source code in prod/adapters/vertex_embedding.py
def embed_documents_batch(
    self,
    texts: list[str],
    titles: list[str] | None = None,
) -> list[list[float] | None]:
    """Embed multiple documents, chunked to the configured batch size.

    Returns one vector per input text (in order); failed items are None.
    """
    if not texts:
        return []
    doc_titles = titles if titles else [""] * len(texts)
    size = self._settings.embed_batch_size
    vectors: list[list[float] | None] = []

    # Walk the inputs in fixed-size windows so each API call stays in limits.
    for offset in range(0, len(texts), size):
        window = slice(offset, offset + size)
        vectors += self._embed_batch(texts[window], _TASK_DOCUMENT, doc_titles[window])

    return vectors

GeminiLLMAdapter

Implements LLMPort using the Vertex AI generateContent REST API.

Key behaviours:

  • Requests responseMimeType: application/json — guaranteed valid JSON output
  • temperature: 0.1 — low temperature for consistent, deterministic re-ranking
  • HTTP 401 → token refresh + retry
  • HTTP 429 / 503 → exponential back-off

gemini_llm

adapters/gemini_llm.py ────────────────────────────────────────────────────────────────────────────── Implements LLMPort using Vertex AI Gemini (generateContent REST API).

Key behaviour
  • Sends systemInstruction + contents in the Vertex AI REST format
  • Requests JSON output via responseMimeType: application/json
  • Token 401 → triggers GCPAuthManager.invalidate() then retries once
  • Retries on 429/503 with exponential back-off
  • Returns raw JSON string (caller parses)
To swap to OpenAI-compatible endpoints
  1. Write OpenAILLMAdapter implementing LLMPort
  2. Change ONE import in services/container.py

GeminiLLMAdapter

Vertex AI Gemini adapter.

Injected into LLMReranker via services/container.py.

Source code in prod/adapters/gemini_llm.py
class GeminiLLMAdapter:
    """Vertex AI Gemini adapter.

    Implements LLMPort via the Vertex AI generateContent REST API.
    Injected into LLMReranker via services/container.py.
    """

    def __init__(self, auth: GCPAuthManager, settings: Settings) -> None:
        self._auth = auth
        self._settings = settings
        self._url = _build_gemini_url(settings)
        # Corporate proxy support: requests expects a scheme-keyed mapping.
        self._proxies = (
            {"https": f"http://{settings.https_proxy}"}
            if settings.https_proxy
            else {}
        )
        logger.debug("GeminiLLMAdapter ready | model=%s", settings.gcp_gemini_model)

    # ── LLMPort implementation ─────────────────────────────────────────────

    @property
    def model_name(self) -> str:
        """Name of the underlying Gemini model (from settings)."""
        return self._settings.gcp_gemini_model

    def generate_json(
        self,
        system_prompt: str,
        user_message: str,
    ) -> str | None:
        """Send a prompt and return the raw JSON response string.

        Args:
            system_prompt: System-level instruction for Gemini.
            user_message:  User-turn message content.

        Returns:
            Raw JSON string from the model, or None on any HTTP-level or
            response-parsing failure (all such failures are logged, not
            raised — see _post_with_retry).

        Raises:
            AuthenticationError: Propagated from GCPAuthManager when the
                bearer token cannot be refreshed.
        """
        payload = self._build_payload(system_prompt, user_message)
        _t_total = time.perf_counter()
        result = self._post_with_retry(payload)
        # End-to-end latency including all retries and back-off sleeps.
        logger.info(
            "⏱ [GeminiLLM] operation=generate_json_total elapsed=%.3fs",
            time.perf_counter() - _t_total,
        )
        return result

    # ── Private helpers ────────────────────────────────────────────────────

    def _build_payload(self, system_prompt: str, user_message: str) -> dict:
        """Assemble the generateContent request body.

        Low temperature (0.1) keeps re-ranking output near-deterministic;
        responseMimeType asks the API to emit valid JSON.
        """
        return {
            "systemInstruction": {
                "parts": [{"text": system_prompt}],
            },
            "contents": [
                {
                    "role": "user",
                    "parts": [{"text": user_message}],
                }
            ],
            "generationConfig": {
                "temperature": 0.1,
                "responseMimeType": "application/json",
            },
        }

    def _post_with_retry(
        self,
        payload: dict,
        retries: int = 3,
    ) -> str | None:
        """POST to Gemini with token refresh on 401 and back-off on 429/503.

        Returns the extracted text on success, or None on any failure —
        this method never raises for HTTP-level errors (the caller treats
        None as a recoverable miss).
        """
        delay = 2.0
        last_exc: Exception | None = None

        # ── Log prompt sizes once (same payload on every attempt) ─────────
        # Local import keeps the telemetry self-contained in this helper.
        import json as _json
        _payload_bytes = len(_json.dumps(payload).encode())
        _sys_chars = len(payload.get("systemInstruction", {})
                         .get("parts", [{}])[0].get("text", ""))
        _usr_chars = len((payload.get("contents", [{}])[0]
                         .get("parts", [{}])[0].get("text", "")))
        # est_tokens uses the rough 4-chars-per-token heuristic.
        logger.info(
            "⏱ [GeminiLLM] prompt_size system_chars=%d user_chars=%d "
            "total_payload_kb=%.1f est_tokens≈%d",
            _sys_chars,
            _usr_chars,
            _payload_bytes / 1024,
            (_sys_chars + _usr_chars) // 4,
        )

        for attempt in range(1, retries + 1):
            # Fetched per attempt so a 401-triggered invalidate() on the
            # previous attempt yields a fresh token here.
            token = self._auth.get_token()
            headers = {
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            }
            try:
                _t0 = time.perf_counter()
                resp = requests.post(
                    self._url,
                    headers=headers,
                    json=payload,
                    proxies=self._proxies,
                    timeout=self._settings.llm_timeout,
                )
                _resp_kb = len(resp.content) / 1024
                logger.info(
                    "⏱ [GeminiLLM] operation=http_post attempt=%d "
                    "status=%d elapsed=%.3fs response_kb=%.1f",
                    attempt, resp.status_code, time.perf_counter() - _t0, _resp_kb,
                )
            except requests.RequestException as exc:
                last_exc = exc
                logger.warning("Gemini HTTP error (attempt %d/%d): %s", attempt, retries, exc)
                time.sleep(delay)
                delay *= 2
                continue

            if resp.status_code == 401:
                # No back-off here: the token was stale, not the service.
                logger.warning("Gemini 401 — invalidating token and retrying")
                self._auth.invalidate()
                continue

            if resp.status_code in (429, 503):
                logger.warning(
                    "Gemini %d (attempt %d/%d) — back-off %.1fs",
                    resp.status_code, attempt, retries, delay,
                )
                time.sleep(delay)
                delay *= 2
                continue

            if not resp.ok:
                # Log but don't raise — caller can handle None gracefully
                logger.error("Gemini HTTP %d: %s", resp.status_code, resp.text[:300])
                return None

            return self._extract_text(resp.json())

        logger.error("Gemini failed after %d attempts", retries)
        return None

    def _extract_text(self, response_json: dict) -> str | None:
        """Pull the text content out of the Gemini generateContent response."""
        try:
            candidates = response_json.get("candidates", [])
            if not candidates:
                logger.warning("Gemini response contained no candidates")
                return None
            parts = candidates[0].get("content", {}).get("parts", [])
            if not parts:
                logger.warning("Gemini candidate contained no parts")
                return None
            text = parts[0].get("text", "").strip()
            return text if text else None
        except (KeyError, IndexError, TypeError) as exc:
            # Defensive: .get() chains make these mostly unreachable, but a
            # non-dict/list shape in the response would land here.
            logger.error("Failed to parse Gemini response structure: %s", exc)
            return None

generate_json

generate_json(system_prompt: str, user_message: str) -> str | None

Send a prompt and return the raw JSON response string.

Parameters:

Name Type Description Default
system_prompt str

System-level instruction for Gemini.

required
user_message str

User-turn message content.

required

Returns:

Type Description
str | None

Raw JSON string from the model, or None on recoverable failure.

Raises:

Type Description
LLMError

On unrecoverable API failure.

Source code in prod/adapters/gemini_llm.py
def generate_json(
    self,
    system_prompt: str,
    user_message: str,
) -> str | None:
    """Send a prompt and return the raw JSON response string.

    Args:
        system_prompt: System-level instruction for Gemini.
        user_message:  User-turn message content.

    Returns:
        Raw JSON string from the model, or None on any HTTP-level or
        response-parsing failure (those are logged, not raised).

    Raises:
        AuthenticationError: Propagated from GCPAuthManager when the
            bearer token cannot be refreshed.
    """
    payload = self._build_payload(system_prompt, user_message)
    _t_total = time.perf_counter()
    result = self._post_with_retry(payload)
    # End-to-end latency including all retries and back-off sleeps.
    logger.info(
        "⏱ [GeminiLLM] operation=generate_json_total elapsed=%.3fs",
        time.perf_counter() - _t_total,
    )
    return result

OpenAIEmbeddingAdapter

Implements EmbeddingPort using the OpenAI Embeddings REST API. Activated when EMBED_PROVIDER=openai.

Key behaviours:

  • Uses /v1/embeddings directly via requests (no SDK dependency)
  • Passes dimensions=EMBED_DIM to the API so output matches the pgvector column width
  • embed_query and embed_document use the same endpoint (OpenAI embeddings are symmetric)
  • HTTP 401 → raises AuthenticationError immediately (key is wrong, no retry)
  • HTTP 429 / 500 / 503 → exponential back-off

Dimension compatibility

If your database was initialised with vector(768) (the Vertex AI default), set EMBED_DIM=768 when using OpenAI models. text-embedding-3-small and text-embedding-3-large both accept a dimensions parameter to reduce their native output to any target size. For a fresh installation with OpenAI, setting EMBED_DIM=1536 gives better quality from text-embedding-3-small.

openai_embedding

adapters/openai_embedding.py ────────────────────────────────────────────────────────────────────────────── Implements EmbeddingPort using the OpenAI Embeddings API.

Key behaviour
  • Uses /v1/embeddings via raw requests (no openai SDK dependency)
  • Passes dimensions=settings.embed_dim so output matches whatever pgvector column width the DB was initialised with
  • embed_query and embed_document call the same endpoint (OpenAI embeddings are symmetric — no RETRIEVAL_QUERY / RETRIEVAL_DOCUMENT distinction)
  • Batches embed_documents_batch to stay within 2048-token-per-item limit
  • Retries on 429 / 500 with exponential back-off
Required env vars

OPENAI_API_KEY — your OpenAI secret key (sk-...)
OPENAI_EMBED_MODEL — default: text-embedding-3-small
EMBED_DIM — default: 768; 1536 recommended for text-embedding-3-small

To enable

Set EMBED_PROVIDER=openai in your .env file.

OpenAIEmbeddingAdapter

OpenAI text-embedding adapter.

Injected into HybridRetriever via services/container.py when EMBED_PROVIDER=openai is set in the environment.

Note on dimensions

OpenAI's text-embedding-3-* models accept a dimensions parameter to reduce the output to any size ≤ the model's native dimension. This adapter always passes settings.embed_dim so output vectors are compatible with the pgvector column width that the database was initialised with.

Source code in prod/adapters/openai_embedding.py
class OpenAIEmbeddingAdapter:
    """OpenAI text-embedding adapter.

    Injected into HybridRetriever via services/container.py when
    ``EMBED_PROVIDER=openai`` is set in the environment.

    Note on dimensions:
        OpenAI's ``text-embedding-3-*`` models accept a ``dimensions``
        parameter to reduce the output to any size ≤ the model's native
        dimension.  This adapter always passes ``settings.embed_dim`` so
        output vectors are compatible with the pgvector column width that
        the database was initialised with.
    """

    def __init__(self, settings: Settings) -> None:
        # Fail fast at wiring time rather than on the first embed call.
        if not settings.openai_api_key:
            raise AuthenticationError(
                "OPENAI_API_KEY is not set. "
                "Add it to your .env file or environment."
            )
        self._settings = settings
        self._headers = {
            "Authorization": f"Bearer {settings.openai_api_key}",
            "Content-Type": "application/json",
        }
        logger.debug(
            "OpenAIEmbeddingAdapter ready | model=%s dim=%d",
            settings.openai_embed_model,
            settings.embed_dim,
        )

    # ── EmbeddingPort implementation ───────────────────────────────────────

    @property
    def model_name(self) -> str:
        """Name of the underlying OpenAI embedding model."""
        return self._settings.openai_embed_model

    @property
    def dimensions(self) -> int:
        """Vector dimensionality (controlled by ``EMBED_DIM`` env var)."""
        return self._settings.embed_dim

    def embed_query(self, text: str) -> list[float]:
        """Embed a search query.

        OpenAI embeddings are symmetric, so the same endpoint and model
        is used for queries and documents.

        Args:
            text: Natural-language query string.

        Returns:
            Dense float vector of length ``settings.embed_dim``.

        Raises:
            EmbeddingError: On API failure or unexpected response shape.
        """
        return self._embed_one(text)

    def embed_document(self, text: str, title: str = "") -> list[float]:
        """Embed a document for storage.

        The ``title`` parameter is accepted for interface compatibility
        but is not used by the OpenAI embeddings API.

        Args:
            text:  Document body text.
            title: Ignored (OpenAI API does not use document titles).

        Returns:
            Dense float vector of length ``settings.embed_dim``.
        """
        return self._embed_one(text)

    def embed_documents_batch(
        self,
        texts: list[str],
        titles: list[str] | None = None,
    ) -> list[list[float] | None]:
        """Embed multiple documents, batched to respect API limits.

        Args:
            texts:  List of document strings to embed.
            titles: Ignored (OpenAI API does not use document titles).

        Returns:
            List of float vectors; ``None`` for any item that failed.
        """
        if not texts:
            return []

        batch_size = self._settings.embed_batch_size
        all_results: list[list[float] | None] = []

        for start in range(0, len(texts), batch_size):
            chunk = texts[start : start + batch_size]
            try:
                batch_results = self._embed_batch(chunk)
                all_results.extend(batch_results)
            except EmbeddingError:
                # Fix: format string was "%d%d" which fused the two indices
                # into one unreadable number (e.g. "items 049").
                logger.warning(
                    "OpenAI batch embed failed for items %d-%d; "
                    "returning None for all %d items in chunk",
                    start,
                    start + len(chunk) - 1,
                    len(chunk),
                )
                all_results.extend([None] * len(chunk))

        return all_results

    # ── Private helpers ────────────────────────────────────────────────────

    def _embed_one(self, text: str) -> list[float]:
        """Call the OpenAI embeddings endpoint for a single text string."""
        payload = {
            "model": self._settings.openai_embed_model,
            "input": text,
            "dimensions": self._settings.embed_dim,
        }
        data = self._post_with_retry(payload)
        try:
            return data["data"][0]["embedding"]
        except (KeyError, IndexError, TypeError) as exc:
            raise EmbeddingError(
                f"Unexpected OpenAI embed response shape: {list(data.keys())}"
            ) from exc

    def _embed_batch(self, texts: list[str]) -> list[list[float] | None]:
        """Call the OpenAI embeddings endpoint for a list of strings."""
        payload = {
            "model": self._settings.openai_embed_model,
            "input": texts,
            "dimensions": self._settings.embed_dim,
        }
        data = self._post_with_retry(payload)
        items = data.get("data", [])
        # API returns items in index order but we validate just in case
        result: list[list[float] | None] = [None] * len(texts)
        for item in items:
            idx = item.get("index", -1)
            if 0 <= idx < len(texts):
                result[idx] = item.get("embedding")
        return result

    def _post_with_retry(
        self,
        payload: dict,
        retries: int = 3,
    ) -> dict:
        """POST to the OpenAI API with retry on 429 / 500 / 503.

        Args:
            payload: JSON body for the /v1/embeddings endpoint.
            retries: Maximum number of attempts.

        Returns:
            Decoded JSON response body.

        Raises:
            AuthenticationError: On HTTP 401 (bad key — never retried).
            EmbeddingError: On other non-retryable HTTP errors, or after
                the retry budget is exhausted.
        """
        delay = 2.0
        last_exc: Exception | None = None

        for attempt in range(1, retries + 1):
            try:
                resp = requests.post(
                    _OPENAI_EMBED_URL,
                    headers=self._headers,
                    json=payload,
                    timeout=self._settings.embed_timeout,
                )
            except requests.RequestException as exc:
                last_exc = exc
                logger.warning(
                    "OpenAI embed request error (attempt %d/%d): %s",
                    attempt, retries, exc,
                )
                if attempt < retries:  # fix: no pointless back-off after the final attempt
                    time.sleep(delay)
                    delay *= 2
                continue

            if resp.status_code == 401:
                # A bad key will not fix itself — fail immediately.
                raise AuthenticationError(
                    "OpenAI returned 401 Unauthorised. "
                    "Check that OPENAI_API_KEY is valid."
                )

            if resp.status_code in (429, 500, 503):
                logger.warning(
                    "OpenAI embed %d (attempt %d/%d) — back-off %.1fs",
                    resp.status_code, attempt, retries, delay,
                )
                if attempt < retries:  # fix: no pointless back-off after the final attempt
                    time.sleep(delay)
                    delay *= 2
                continue

            if not resp.ok:
                raise EmbeddingError(
                    f"OpenAI embed HTTP {resp.status_code}: {resp.text[:300]}"
                )

            return resp.json()

        raise EmbeddingError(
            f"OpenAI embed failed after {retries} attempts"
        ) from last_exc

model_name property

model_name: str

Name of the underlying OpenAI embedding model.

dimensions property

dimensions: int

Vector dimensionality (controlled by EMBED_DIM env var).

embed_query

embed_query(text: str) -> list[float]

Embed a search query.

OpenAI embeddings are symmetric, so the same endpoint and model is used for queries and documents.

Parameters:

Name Type Description Default
text str

Natural-language query string.

required

Returns:

Type Description
list[float]

Dense float vector of length settings.embed_dim.

Raises:

Type Description
EmbeddingError

On API failure or unexpected response shape.

Source code in prod/adapters/openai_embedding.py
def embed_query(self, text: str) -> list[float]:
    """Embed a search query string.

    OpenAI embeddings are symmetric — one endpoint and model serves both
    queries and documents — so this simply delegates to the shared
    single-text helper.

    Args:
        text: Natural-language query string.

    Returns:
        Dense float vector of length ``settings.embed_dim``.

    Raises:
        EmbeddingError: On API failure or unexpected response shape.
    """
    vector = self._embed_one(text)
    return vector

embed_document

embed_document(text: str, title: str = '') -> list[float]

Embed a document for storage.

The title parameter is accepted for interface compatibility but is not used by the OpenAI embeddings API.

Parameters:

Name Type Description Default
text str

Document body text.

required
title str

Ignored (OpenAI API does not use document titles).

''

Returns:

Type Description
list[float]

Dense float vector of length settings.embed_dim.

Source code in prod/adapters/openai_embedding.py
def embed_document(self, text: str, title: str = "") -> list[float]:
    """Embed a document body for storage.

    ``title`` exists purely for interface parity with the other embedding
    adapters; the OpenAI embeddings API has no notion of document titles,
    so it is ignored here.

    Args:
        text:  Document body text.
        title: Ignored (OpenAI API does not use document titles).

    Returns:
        Dense float vector of length ``settings.embed_dim``.
    """
    del title  # accepted for interface compatibility only
    return self._embed_one(text)

embed_documents_batch

embed_documents_batch(texts: list[str], titles: list[str] | None = None) -> list[list[float] | None]

Embed multiple documents, batched to respect API limits.

Parameters:

Name Type Description Default
texts list[str]

List of document strings to embed.

required
titles list[str] | None

Ignored (OpenAI API does not use document titles).

None

Returns:

Type Description
list[list[float] | None]

List of float vectors; None for any item that failed.

Source code in prod/adapters/openai_embedding.py
def embed_documents_batch(
    self,
    texts: list[str],
    titles: list[str] | None = None,
) -> list[list[float] | None]:
    """Embed multiple documents, batched to respect API limits.

    Args:
        texts:  List of document strings to embed.
        titles: Ignored (OpenAI API does not use document titles).

    Returns:
        List of float vectors; ``None`` for any item that failed.
    """
    if not texts:
        return []

    batch_size = self._settings.embed_batch_size
    all_results: list[list[float] | None] = []

    for start in range(0, len(texts), batch_size):
        chunk = texts[start : start + batch_size]
        try:
            batch_results = self._embed_batch(chunk)
            all_results.extend(batch_results)
        except EmbeddingError:
            logger.warning(
                "OpenAI batch embed failed for items %d%d; "
                "returning None for all %d items in chunk",
                start,
                start + len(chunk) - 1,
                len(chunk),
            )
            all_results.extend([None] * len(chunk))

    return all_results

OpenAILLMAdapter

Implements LLMPort using the OpenAI Chat Completions REST API. Activated when LLM_PROVIDER=openai.

Key behaviours:

  • Uses /v1/chat/completions with response_format: {"type": "json_object"} — guarantees syntactically valid JSON output (same guarantee as Gemini's responseMimeType: application/json)
  • temperature: 0.1 — consistent with the Gemini adapter
  • The existing build_system_prompt() already includes the word "JSON", satisfying OpenAI's JSON mode requirement with no prompt changes needed
  • HTTP 401 → raises AuthenticationError immediately
  • HTTP 429 / 500 / 503 → exponential back-off

openai_llm

adapters/openai_llm.py ────────────────────────────────────────────────────────────────────────────── Implements LLMPort using the OpenAI Chat Completions API.

Key behaviour
  • Uses /v1/chat/completions via raw requests (no openai SDK dependency)
  • Requests JSON output via response_format={"type": "json_object"}
  • system_prompt → system role message; user_message → user role message
  • Retries on 429 / 500 / 503 with exponential back-off
  • Returns the raw JSON string (caller parses); None on recoverable failure
Required env vars

OPENAI_API_KEY — your OpenAI secret key (sk-...) OPENAI_LLM_MODEL — default: gpt-4o

To enable

Set LLM_PROVIDER=openai in your .env file.

OpenAILLMAdapter

OpenAI GPT chat completions adapter.

Injected into LLMReranker via services/container.py when LLM_PROVIDER=openai is set in the environment.

JSON mode is enabled via response_format={"type": "json_object"}, which guarantees the model returns syntactically valid JSON without needing to strip markdown fences.

.. note:: OpenAI's JSON mode requires the word "JSON" to appear somewhere in the prompt. The existing build_system_prompt() in config/prompts.py already includes this — no prompt changes are needed.

Source code in prod/adapters/openai_llm.py
class OpenAILLMAdapter:
    """OpenAI GPT chat completions adapter.

    Injected into LLMReranker via services/container.py when
    ``LLM_PROVIDER=openai`` is set in the environment.

    JSON mode is enabled via ``response_format={"type": "json_object"}``,
    which guarantees the model returns syntactically valid JSON without
    needing to strip markdown fences.

    .. note::
        OpenAI's JSON mode requires the word "JSON" to appear somewhere in
        the prompt.  The existing ``build_system_prompt()`` in
        ``config/prompts.py`` already includes this — no prompt changes are
        needed.
    """

    def __init__(self, settings: Settings) -> None:
        """Validate the API key and prepare static request headers.

        Raises:
            AuthenticationError: If ``OPENAI_API_KEY`` is not configured.
        """
        if not settings.openai_api_key:
            raise AuthenticationError(
                "OPENAI_API_KEY is not set. "
                "Add it to your .env file or environment."
            )
        self._settings = settings
        self._headers = {
            "Authorization": f"Bearer {settings.openai_api_key}",
            "Content-Type": "application/json",
        }
        logger.debug("OpenAILLMAdapter ready | model=%s", settings.openai_llm_model)

    # ── LLMPort implementation ─────────────────────────────────────────────

    @property
    def model_name(self) -> str:
        """Name of the underlying OpenAI chat model."""
        return self._settings.openai_llm_model

    def generate_json(
        self,
        system_prompt: str,
        user_message: str,
    ) -> str | None:
        """Send a prompt and return the raw JSON response string.

        Uses ``response_format={"type": "json_object"}`` to enforce valid
        JSON output from the model.  The caller is responsible for parsing
        the returned string.

        Args:
            system_prompt: System-level instruction for the model.
            user_message:  User-turn message content.

        Returns:
            Raw JSON string from the model, or ``None`` on failure
            (request errors, retry exhaustion, non-2xx responses, or an
            empty/malformed response body).

        Raises:
            AuthenticationError: If the API returns HTTP 401
                (invalid ``OPENAI_API_KEY``).
        """
        payload = self._build_payload(system_prompt, user_message)
        return self._post_with_retry(payload)

    # ── Private helpers ────────────────────────────────────────────────────

    def _build_payload(self, system_prompt: str, user_message: str) -> dict:
        """Build the OpenAI chat completions request body.

        Low temperature (0.1) keeps output stable across calls, consistent
        with the Gemini adapter.
        """
        return {
            "model": self._settings.openai_llm_model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message},
            ],
            "temperature": 0.1,
            "response_format": {"type": "json_object"},
        }

    def _post_with_retry(
        self,
        payload: dict,
        retries: int = 3,
    ) -> str | None:
        """POST to the OpenAI API with back-off on 429 / 500 / 503.

        Returns ``None`` once all attempts are exhausted; raises
        AuthenticationError immediately on 401 (retrying cannot help).
        """
        delay = 2.0
        last_exc: Exception | None = None

        for attempt in range(1, retries + 1):
            try:
                resp = requests.post(
                    _OPENAI_CHAT_URL,
                    headers=self._headers,
                    json=payload,
                    timeout=self._settings.llm_timeout,
                )
            except requests.RequestException as exc:
                last_exc = exc
                logger.warning(
                    "OpenAI LLM request error (attempt %d/%d): %s",
                    attempt, retries, exc,
                )
                if attempt < retries:
                    # Only back off when another attempt will follow —
                    # sleeping after the final failure just adds latency.
                    time.sleep(delay)
                    delay *= 2
                continue

            if resp.status_code == 401:
                raise AuthenticationError(
                    "OpenAI returned 401 Unauthorised. "
                    "Check that OPENAI_API_KEY is valid."
                )

            if resp.status_code in (429, 500, 503):
                logger.warning(
                    "OpenAI LLM %d (attempt %d/%d) — back-off %.1fs",
                    resp.status_code, attempt, retries, delay,
                )
                if attempt < retries:
                    time.sleep(delay)
                    delay *= 2
                continue

            if not resp.ok:
                # Non-retryable HTTP error (e.g. 400/404): recoverable from
                # the caller's perspective, so log and return None.
                logger.error(
                    "OpenAI LLM HTTP %d: %s",
                    resp.status_code, resp.text[:300],
                )
                return None

            return self._extract_text(resp.json())

        # Include the last transport-level error so the log is actionable.
        logger.error(
            "OpenAI LLM failed after %d attempts (last error: %s)",
            retries, last_exc,
        )
        return None

    def _extract_text(self, response_json: dict) -> str | None:
        """Pull the content string out of the chat completions response."""
        try:
            choices = response_json.get("choices", [])
            if not choices:
                logger.warning("OpenAI response contained no choices")
                return None
            content = choices[0].get("message", {}).get("content", "").strip()
            return content if content else None
        except (AttributeError, KeyError, IndexError, TypeError) as exc:
            # AttributeError covers non-dict items in an unexpected payload.
            logger.error("Failed to parse OpenAI response structure: %s", exc)
            return None

model_name property

model_name: str

Name of the underlying OpenAI chat model.

generate_json

generate_json(system_prompt: str, user_message: str) -> str | None

Send a prompt and return the raw JSON response string.

Uses response_format={"type": "json_object"} to enforce valid JSON output from the model. The caller is responsible for parsing the returned string.

Parameters:

Name Type Description Default
system_prompt str

System-level instruction for the model.

required
user_message str

User-turn message content.

required

Returns:

Type Description
str | None

Raw JSON string from the model, or None on recoverable failure.

Raises:

Type Description
AuthenticationError

If the API returns HTTP 401 (invalid OPENAI_API_KEY). All other failures are logged and None is returned.

Source code in prod/adapters/openai_llm.py
def generate_json(
    self,
    system_prompt: str,
    user_message: str,
) -> str | None:
    """Send a prompt and return the raw JSON response string.

    Uses ``response_format={"type": "json_object"}`` to enforce valid
    JSON output from the model.  The caller is responsible for parsing
    the returned string.

    Args:
        system_prompt: System-level instruction for the model.
        user_message:  User-turn message content.

    Returns:
        Raw JSON string from the model, or ``None`` on failure
        (request errors, retry exhaustion, non-2xx responses, or an
        empty/malformed response body).

    Raises:
        AuthenticationError: If the API returns HTTP 401
            (invalid ``OPENAI_API_KEY``).
    """
    payload = self._build_payload(system_prompt, user_message)
    return self._post_with_retry(payload)

PostgresDatabaseAdapter

Implements DatabasePort using psycopg2 and the pgvector extension.

Key behaviours:

  • A ThreadedConnectionPool is created lazily on the first query; each call borrows a connection and returns it afterwards
  • On OperationalError the stale connection is discarded (closed and removed from the pool) so a fresh one is used next
  • vector_search uses the HNSW index via the <=> cosine operator
  • fts_search uses the GIN-indexed tsvector column
  • fetch_by_codes uses ANY(%s) for a single round-trip to fetch N records

Scaling to FastAPI

The implementation uses a psycopg2.pool.ThreadedConnectionPool, so it is already suitable for a multi-threaded FastAPI deployment — each worker thread borrows its own connection per query. Pool sizing is controlled by the _POOL_MINCONN / _POOL_MAXCONN constants in this file; no other file needs to change.

postgres_db

adapters/postgres_db.py ────────────────────────────────────────────────────────────────────────────── Implements DatabasePort using psycopg2 + pgvector.

Database layout (from ingest.py): Table : anzsic_codes Cols : anzsic_code (PK), anzsic_desc, class_code, class_desc, group_code, group_desc, subdivision_desc, division_desc, class_exclusions, enriched_text, embedding vector(768), fts_vector Index : HNSW cosine (embedding), GIN (fts_vector)

Three atomic methods match DatabasePort

vector_search → ANN search via pgvector <=> operator fts_search → FTS via tsquery fetch_by_codes → bulk SELECT by primary key list

Connection management
  • A ThreadedConnectionPool is created lazily on first use and shared by all threads in the process.
  • On OperationalError the stale connection is discarded (closed and removed from the pool) so a fresh one is used next.
  • Suitable for Streamlit (single-process) and FastAPI (Uvicorn thread pool) alike — connections are borrowed per query and always returned to the pool.

To swap the database engine (e.g. to Weaviate or Pinecone): 1. Write WeaviateDatabaseAdapter implementing DatabasePort 2. Change ONE import in services/container.py

PostgresDatabaseAdapter

psycopg2 + pgvector implementation of DatabasePort.

Uses a ThreadedConnectionPool so concurrent threads (FastAPI + Uvicorn thread pool) each borrow their own connection. The pool is created once per adapter instance and shared across all threads in a process.

Injected into HybridRetriever via services/container.py.

Source code in prod/adapters/postgres_db.py
class PostgresDatabaseAdapter:
    """psycopg2 + pgvector implementation of DatabasePort.

    Uses a ThreadedConnectionPool so concurrent threads (FastAPI + Uvicorn
    thread pool) each borrow their own connection.  The pool is created once
    per adapter instance and shared across all threads in a process.

    On OperationalError a query is retried once on a fresh connection
    before failing, so transient connection drops are self-healing.

    Injected into HybridRetriever via services/container.py.
    """

    def __init__(self, settings: Settings) -> None:
        self._dsn = settings.db_dsn
        self._pool: Any = None
        # Legacy single-connection attribute kept for close() backward compat
        self._conn: Any = None
        # NOTE(review): the DSN may embed credentials — consider redacting
        # it before logging, even at DEBUG level.
        logger.debug("PostgresDatabaseAdapter ready | dsn=%s", self._dsn)

    # ── DatabasePort implementation ────────────────────────────────────────

    def vector_search(
        self,
        embedding: list[float],
        limit: int,
    ) -> list[tuple[str, int]]:
        """Approximate nearest-neighbour search via pgvector HNSW index.

        Args:
            embedding: Query vector (must match the stored vector dimension).
            limit:     Maximum number of hits to return.

        Returns:
            List of (anzsic_code, rank) tuples, rank starting at 1.

        Raises:
            DatabaseError: On any query failure.
        """
        sql = """
            SELECT anzsic_code,
                   ROW_NUMBER() OVER (ORDER BY embedding <=> %s::vector) AS rank
            FROM   anzsic_codes
            WHERE  embedding IS NOT NULL
            ORDER  BY embedding <=> %s::vector
            LIMIT  %s
        """
        try:
            rows = self._execute(sql, (embedding, embedding, limit))
            return [(row["anzsic_code"], row["rank"]) for row in rows]
        except Exception as exc:
            raise DatabaseError(f"vector_search failed: {exc}") from exc

    def fts_search(
        self,
        query_text: str,
        limit: int,
    ) -> list[tuple[str, int]]:
        """Full-text search using the GIN-indexed tsvector column.

        Uses OR between stemmed query tokens so that descriptive free-text
        queries like "fixes pipes in industries for AC" match records containing
        ANY of the meaningful terms (pipe, fix, industri, etc.) rather than
        requiring ALL terms to be present in the same record (AND semantics of
        plainto_tsquery would return zero hits for most natural-language inputs).

        Falls back to an empty list rather than raising if anything fails
        (colloquial queries often produce zero FTS hits — vector covers it).
        """
        sql = """
            SELECT anzsic_code,
                   ROW_NUMBER() OVER (
                       ORDER BY ts_rank_cd(fts_vector, query) DESC
                   ) AS rank
            FROM   anzsic_codes,
                   (SELECT to_tsquery(string_agg(lexeme, ' | '))
                    FROM   unnest(to_tsvector('english', %s))
                   ) AS t(query)
            WHERE  query IS NOT NULL
              AND  fts_vector @@ query
            ORDER  BY ts_rank_cd(fts_vector, query) DESC
            LIMIT  %s
        """
        try:
            rows = self._execute(sql, (query_text, limit))
            return [(row["anzsic_code"], row["rank"]) for row in rows]
        except Exception as exc:
            logger.warning("fts_search error (returning empty): %s", exc)
            return []

    def fetch_by_codes(self, codes: list[str]) -> dict[str, dict]:
        """Fetch full records for a list of ANZSIC codes in one round-trip.

        Returns a dict keyed by anzsic_code.  Missing codes are absent.

        Raises:
            DatabaseError: On any query failure.
        """
        if not codes:
            return {}
        sql = f"""
            SELECT {_SELECT_COLS}
            FROM   anzsic_codes
            WHERE  anzsic_code = ANY(%s)
        """
        try:
            rows = self._execute(sql, (codes,))
            return {row["anzsic_code"]: dict(row) for row in rows}
        except Exception as exc:
            raise DatabaseError(f"fetch_by_codes failed: {exc}") from exc

    # ── Connection pool helpers ────────────────────────────────────────────

    def _get_pool(self) -> Any:
        """Return (or lazily create) the ThreadedConnectionPool.

        NOTE(review): lazy creation is not lock-guarded; two threads racing
        through first use could each build a pool (one would leak).  In
        practice the container wires adapters before traffic is served —
        confirm, or guard this with a threading.Lock.
        """
        if self._pool is None:
            try:
                pool = psycopg2.pool.ThreadedConnectionPool(
                    _POOL_MINCONN,
                    _POOL_MAXCONN,
                    self._dsn,
                )
                # Register pgvector on the eagerly-created connections.
                # Bug fix: psycopg2's pool hands connections out LIFO, so an
                # interleaved getconn/putconn loop would configure the SAME
                # connection repeatedly.  Check all of them out first, then
                # configure, then return them.
                bootstrap = [pool.getconn() for _ in range(_POOL_MINCONN)]
                for conn in bootstrap:
                    conn.autocommit = True
                    register_vector(conn)
                for conn in bootstrap:
                    pool.putconn(conn)
                logger.info(
                    "PostgresDatabaseAdapter: pool created min=%d max=%d",
                    _POOL_MINCONN,
                    _POOL_MAXCONN,
                )
                self._pool = pool
            except psycopg2.Error as exc:
                raise DatabaseError(f"Cannot create connection pool: {exc}") from exc
        return self._pool

    def _new_conn(self) -> Any:
        """Open a single standalone connection with pgvector registered.

        Not used by the pooled query path; retained for ad-hoc bootstrap /
        backward-compatible callers.
        """
        try:
            conn = psycopg2.connect(self._dsn)
            conn.autocommit = True
            register_vector(conn)
            logger.debug("PostgresDatabaseAdapter: new connection opened")
            return conn
        except psycopg2.Error as exc:
            raise DatabaseError(f"Cannot connect to database: {exc}") from exc

    def _execute(self, sql: str, params: tuple) -> list[dict]:
        """Execute a query on a pooled connection, retrying once if stale.

        Implements the documented contract: on OperationalError the broken
        connection is discarded and the query is retried once on a fresh
        connection before giving up.

        Raises:
            DatabaseError: If the retry also fails with OperationalError.
        """
        try:
            return self._execute_once(sql, params)
        except psycopg2.OperationalError as exc:
            logger.warning("DB OperationalError — replacing stale connection: %s", exc)
            try:
                return self._execute_once(sql, params)
            except psycopg2.OperationalError as exc2:
                raise DatabaseError(
                    f"DB query failed (stale connection discarded): {exc2}"
                ) from exc2

    def _execute_once(self, sql: str, params: tuple) -> list[dict]:
        """Run one query on a borrowed connection; always return it to the pool."""
        pool = self._get_pool()
        conn = pool.getconn()
        try:
            if conn.autocommit is False:
                # Demand-created pool connection (beyond minconn) — apply the
                # same one-time setup the bootstrap gives the eager ones.
                conn.autocommit = True
                register_vector(conn)
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                rows = list(cur.fetchall())
        except psycopg2.OperationalError:
            # Stale/broken connection: drop it from the pool entirely.
            pool.putconn(conn, close=True)
            raise
        except Exception:
            pool.putconn(conn)
            raise
        pool.putconn(conn)
        return rows

    def close(self) -> None:
        """Close all connections in the pool (called on process shutdown).

        Idempotent: the pool reference is cleared so a second call is a
        no-op instead of raising PoolError on an already-closed pool.
        """
        if self._pool:
            self._pool.closeall()
            self._pool = None
            logger.debug("PostgresDatabaseAdapter: pool closed")
vector_search(embedding: list[float], limit: int) -> list[tuple[str, int]]

Approximate nearest-neighbour search via pgvector HNSW index.

Returns list of (anzsic_code, rank) tuples, rank starting at 1.

Source code in prod/adapters/postgres_db.py
def vector_search(
    self,
    embedding: list[float],
    limit: int,
) -> list[tuple[str, int]]:
    """Nearest-neighbour lookup using the pgvector HNSW index.

    Orders by cosine distance (the ``<=>`` operator) and yields
    (anzsic_code, rank) pairs, rank counting from 1.  Any underlying
    failure is wrapped in DatabaseError.
    """
    sql = """
        SELECT anzsic_code,
               ROW_NUMBER() OVER (ORDER BY embedding <=> %s::vector) AS rank
        FROM   anzsic_codes
        WHERE  embedding IS NOT NULL
        ORDER  BY embedding <=> %s::vector
        LIMIT  %s
    """
    try:
        hits: list[tuple[str, int]] = []
        for record in self._execute(sql, (embedding, embedding, limit)):
            hits.append((record["anzsic_code"], record["rank"]))
        return hits
    except Exception as exc:
        raise DatabaseError(f"vector_search failed: {exc}") from exc
fts_search(query_text: str, limit: int) -> list[tuple[str, int]]

Full-text search using the GIN-indexed tsvector column.

Uses OR between stemmed query tokens so that descriptive free-text queries like "fixes pipes in industries for AC" match records containing ANY of the meaningful terms (pipe, fix, industri, etc.) rather than requiring ALL terms to be present in the same record (AND semantics of plainto_tsquery would return zero hits for most natural-language inputs).

Falls back to an empty list rather than raising if no FTS results (colloquial queries often produce zero FTS hits — vector covers it).

Source code in prod/adapters/postgres_db.py
def fts_search(
    self,
    query_text: str,
    limit: int,
) -> list[tuple[str, int]]:
    """Full-text search over the GIN-indexed tsvector column.

    Stemmed query tokens are OR-ed together so a free-text query such as
    "fixes pipes in industries for AC" matches records containing ANY
    meaningful term (pipe, fix, industri, ...).  The AND semantics of
    plainto_tsquery would demand every term in one record and return zero
    hits for most natural-language input.

    Degrades gracefully: any error yields an empty list instead of raising,
    since colloquial queries often have no FTS hits and vector search
    covers those cases.
    """
    sql = """
        SELECT anzsic_code,
               ROW_NUMBER() OVER (
                   ORDER BY ts_rank_cd(fts_vector, query) DESC
               ) AS rank
        FROM   anzsic_codes,
               (SELECT to_tsquery(string_agg(lexeme, ' | '))
                FROM   unnest(to_tsvector('english', %s))
               ) AS t(query)
        WHERE  query IS NOT NULL
          AND  fts_vector @@ query
        ORDER  BY ts_rank_cd(fts_vector, query) DESC
        LIMIT  %s
    """
    try:
        matches = [
            (entry["anzsic_code"], entry["rank"])
            for entry in self._execute(sql, (query_text, limit))
        ]
    except Exception as exc:
        logger.warning("fts_search error (returning empty): %s", exc)
        matches = []
    return matches

fetch_by_codes

fetch_by_codes(codes: list[str]) -> dict[str, dict]

Fetch full records for a list of ANZSIC codes.

Returns a dict keyed by anzsic_code. Missing codes are absent.

Source code in prod/adapters/postgres_db.py
def fetch_by_codes(self, codes: list[str]) -> dict[str, dict]:
    """Bulk-fetch full records for the given ANZSIC codes.

    A single round-trip using ``ANY(%s)``.  The result dict is keyed by
    anzsic_code; codes not present in the table simply have no entry.
    An empty input short-circuits to an empty dict with no query.
    """
    if not codes:
        return {}
    sql = f"""
        SELECT {_SELECT_COLS}
        FROM   anzsic_codes
        WHERE  anzsic_code = ANY(%s)
    """
    try:
        records: dict[str, dict] = {}
        for entry in self._execute(sql, (codes,)):
            records[entry["anzsic_code"]] = dict(entry)
        return records
    except Exception as exc:
        raise DatabaseError(f"fetch_by_codes failed: {exc}") from exc

close

close() -> None

Close all connections in the pool (called on process shutdown).

Source code in prod/adapters/postgres_db.py
def close(self) -> None:
    """Release every pooled connection (invoked at process shutdown)."""
    if not self._pool:
        return
    self._pool.closeall()
    logger.debug("PostgresDatabaseAdapter: pool closed")