"""Azure Key Vault settings source.""" from __future__ import annotations as _annotations from collections.abc import Iterator, Mapping from typing import TYPE_CHECKING from pydantic.alias_generators import to_snake from pydantic.fields import FieldInfo from .env import EnvSettingsSource if TYPE_CHECKING: from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.keyvault.secrets import SecretClient from pydantic_settings.main import BaseSettings else: TokenCredential = None ResourceNotFoundError = None SecretClient = None def import_azure_key_vault() -> None: global TokenCredential global SecretClient global ResourceNotFoundError try: from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.keyvault.secrets import SecretClient except ImportError as e: # pragma: no cover raise ImportError( 'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`' ) from e class AzureKeyVaultMapping(Mapping[str, str | None]): _loaded_secrets: dict[str, str | None] _secret_client: SecretClient _secret_names: list[str] def __init__( self, secret_client: SecretClient, case_sensitive: bool, snake_case_conversion: bool, env_prefix: str | None, ) -> None: self._loaded_secrets = {} self._secret_client = secret_client self._case_sensitive = case_sensitive self._snake_case_conversion = snake_case_conversion self._env_prefix = env_prefix if env_prefix else '' self._secret_map: dict[str, str] = self._load_remote() def _load_remote(self) -> dict[str, str]: secret_names: Iterator[str] = ( secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled ) if self._snake_case_conversion: name_map: dict[str, str] = {} for name in secret_names: if name.startswith(self._env_prefix): name_map[f'{self._env_prefix}{to_snake(name[len(self._env_prefix) :])}'] = name else: name_map[to_snake(name)] = name return name_map if self._case_sensitive: return {name: name for name in secret_names} return {name.lower(): name for name in secret_names} def __getitem__(self, key: str) -> str | None: new_key = key if self._snake_case_conversion: if key.startswith(self._env_prefix): new_key = f'{self._env_prefix}{to_snake(key[len(self._env_prefix) :])}' else: new_key = to_snake(key) elif not self._case_sensitive: new_key = key.lower() if new_key not in self._loaded_secrets: if new_key in self._secret_map: self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value else: raise KeyError(key) return self._loaded_secrets[new_key] def __len__(self) -> int: return len(self._secret_map) def __iter__(self) -> Iterator[str]: return iter(self._secret_map.keys()) class AzureKeyVaultSettingsSource(EnvSettingsSource): _url: str _credential: TokenCredential def __init__( self, settings_cls: type[BaseSettings], url: str, credential: TokenCredential, dash_to_underscore: bool = False, case_sensitive: bool | None = None, snake_case_conversion: bool = False, env_prefix: str | None = None, env_parse_none_str: str | None = None, env_parse_enums: bool | None = None, ) -> None: import_azure_key_vault() self._url = url self._credential = credential self._dash_to_underscore = dash_to_underscore self._snake_case_conversion = snake_case_conversion super().__init__( settings_cls, case_sensitive=True if snake_case_conversion else case_sensitive, env_prefix=env_prefix, env_nested_delimiter='__' if snake_case_conversion else '--', env_ignore_empty=False, env_parse_none_str=env_parse_none_str, env_parse_enums=env_parse_enums, ) def _load_env_vars(self) -> Mapping[str, str | None]: secret_client = SecretClient(vault_url=self._url, credential=self._credential) return AzureKeyVaultMapping( secret_client=secret_client, case_sensitive=self.case_sensitive, snake_case_conversion=self._snake_case_conversion, env_prefix=self.env_prefix, ) def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]: if self._snake_case_conversion: field_info = list((x[0], x[1], x[2]) for x in super()._extract_field_info(field, field_name)) return field_info if self._dash_to_underscore: return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name)) return super()._extract_field_info(field, field_name) def __repr__(self) -> str: return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})' __all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']