Source code for pyflocker.ciphers.backends.cryptodome_.AES

"""Implementation of AES cipher."""

from __future__ import annotations

import contextlib
import typing
from types import MappingProxyType
from typing import TYPE_CHECKING

from Cryptodome.Cipher import AES

from pyflocker.ciphers import base, exc, modes
from pyflocker.ciphers.backends.symmetric import FileCipherWrapper, HMACWrapper
from pyflocker.ciphers.base import BaseAEADOneShotCipher
from pyflocker.ciphers.modes import Modes

from . import Hash
from .misc import derive_hkdf_key
from .symmetric import (
    AEADCipherTemplate,
    AuthenticationMixin,
    NonAEADCipherTemplate,
)

if TYPE_CHECKING:  # pragma: no cover
    import io

# Table of every pyflocker ``Modes`` member this backend accepts, mapped to
# the Cryptodome ``AES.MODE_*`` constant that ``_get_aes_cipher`` passes to
# ``AES.new``. Wrapped in a MappingProxyType so the table is read-only.
SUPPORTED = MappingProxyType(
    {
        # classic modes
        Modes.MODE_CTR: AES.MODE_CTR,
        Modes.MODE_CFB: AES.MODE_CFB,
        # CFB8 is also backed by Cryptodome's CFB; only MODE_CFB is given a
        # 128-bit segment_size in _get_aes_cipher, so CFB8 keeps Cryptodome's
        # default (8-bit) segments.
        Modes.MODE_CFB8: AES.MODE_CFB,  # compat with pyca/cryptography
        Modes.MODE_OFB: AES.MODE_OFB,
        # AEAD modes
        Modes.MODE_GCM: AES.MODE_GCM,
        Modes.MODE_EAX: AES.MODE_EAX,
        Modes.MODE_SIV: AES.MODE_SIV,
        Modes.MODE_CCM: AES.MODE_CCM,
        Modes.MODE_OCB: AES.MODE_OCB,
    }
)

# Only SUPPORTED needed this name; remove it from the module namespace.
del MappingProxyType


def _get_aes_cipher(
    key: bytes,
    mode: Modes,
    iv_or_nonce: bytes,
) -> typing.Any:
    """Instantiate the Cryptodome AES object that backs *mode*.

    Normally ``iv_or_nonce`` is handed to ``AES.new`` positionally. Two modes
    are adjusted so the output matches pyca/cryptography:

    * ``MODE_CFB`` uses 128-bit segments. ``MODE_CFB8`` is not touched here
      and keeps Cryptodome's default segment size.
    * ``MODE_CTR`` interprets all of ``iv_or_nonce`` as the big-endian
      initial counter value, with an empty Cryptodome nonce.

    Args:
        key: The AES key.
        mode: A mode present in ``SUPPORTED``.
        iv_or_nonce: The IV, nonce or (for CTR) initial counter block.

    Returns:
        The Cryptodome cipher object returned by ``AES.new``.
    """
    if mode == Modes.MODE_CTR:
        # pyca/cryptography's CTR "nonce" is Cryptodome's initial counter value.
        counter_start = int.from_bytes(iv_or_nonce, "big")
        return AES.new(  # type: ignore
            key,
            SUPPORTED[mode],
            initial_value=counter_start,
            nonce=typing.cast("int", b""),
        )

    extra_options: dict[str, typing.Any] = {}
    if mode == Modes.MODE_CFB:
        # compat with pyca/cryptography's CFB(...) mode
        extra_options["segment_size"] = 128

    return AES.new(key, SUPPORTED[mode], iv_or_nonce, **extra_options)  # type: ignore


class AEAD(AEADCipherTemplate):
    """AES cipher for AEAD modes that are fed data incrementally.

    ``new`` builds one of these for AEAD modes that are not in
    ``modes.SPECIAL``. The Cryptodome cipher is created once, and its
    ``encrypt`` or ``decrypt`` bound method is stored as ``_update_func``.
    """

    def __init__(
        self,
        encrypting: bool,
        key: bytes,
        mode: Modes,
        nonce: bytes,
    ) -> None:
        backend_cipher = _get_aes_cipher(key, mode, nonce)

        self._cipher = backend_cipher
        self._updated = False
        self._encrypting = encrypting
        self._mode = mode
        self._tag = None
        # Creating a context is relatively expensive here, so the bound
        # method is resolved once and kept for later updates.
        self._update_func = (
            backend_cipher.encrypt if encrypting else backend_cipher.decrypt
        )

    @property
    def mode(self) -> Modes:  # pragma: no cover
        """The AES mode."""
        return self._mode
class NonAEAD(NonAEADCipherTemplate):
    """AES cipher for classic (non-AEAD) modes such as CTR, CFB and OFB.

    ``new`` uses it directly for non-AEAD modes. ``_wrap_hmac`` builds one
    from a derived key and wraps it in an ``HMACWrapper``.
    """

    def __init__(
        self,
        encrypting: bool,
        key: bytes,
        mode: Modes,
        nonce: bytes,
    ) -> None:
        self._cipher = _get_aes_cipher(key, mode, nonce)
        self._updated = False
        self._encrypting = encrypting
        self._mode = mode
        # Creating a context is relatively expensive here; pick the bound
        # method once instead of looking it up on every update.
        if encrypting:
            self._update_func = self._cipher.encrypt
        else:
            self._update_func = self._cipher.decrypt

    @property
    def mode(self) -> Modes:  # pragma: no cover
        """The AES mode."""
        return self._mode
class AEADOneShot(AuthenticationMixin, BaseAEADOneShotCipher):
    """AES cipher for AEAD modes that process the whole message at once.

    A single ``update`` or ``update_into`` call encrypts or decrypts the
    entire message, and the cipher is finalized right after. ``new`` uses
    this class for modes in ``modes.SPECIAL``.
    """

    # Modes for which no ``output`` buffer is passed to Cryptodome;
    # ``update_into`` raises NotImplementedError for them.
    _write_into_buffer_unsupported = {Modes.MODE_OCB}

    def __init__(
        self,
        encrypting: bool,
        key: bytes,
        mode: Modes,
        nonce: bytes,
    ) -> None:
        self._cipher = _get_aes_cipher(key, mode, nonce)
        self._updated = False
        self._encrypting = encrypting
        self._mode = mode
        # creating a context is relatively expensive here
        self._update_func = self._get_update_func(encrypting, self._cipher)

    def is_encrypting(self) -> bool:
        """Return True if this cipher encrypts, False if it decrypts."""
        return self._encrypting

    @property
    def mode(self) -> Modes:  # pragma: no cover
        """The AES mode."""
        return self._mode

    @staticmethod
    def _get_update_func(
        encrypting: bool,
        cipher: typing.Any,
    ) -> typing.Callable:
        """Wrap Cryptodome's one-shot call as ``f(data, tag, **kwargs)``.

        When encrypting, the tag argument is ignored and only the ciphertext
        element of ``encrypt_and_digest``'s result is returned. When
        decrypting, both data and tag go to ``decrypt_and_verify``.
        """
        if encrypting:
            func = cipher.encrypt_and_digest
            return lambda data, _=None, **k: func(data, **k)[0]
        func = cipher.decrypt_and_verify
        return lambda data, tag, **k: func(data, tag, **k)

    def update(self, data: bytes, tag: bytes | None = None) -> bytes:
        """Encrypt or decrypt *data* in one call, then finalize the cipher.

        Args:
            data: Plaintext when encrypting, ciphertext when decrypting.
            tag: Authentication tag. Required for decryption, ignored for
                encryption.

        Returns:
            The ciphertext or the plaintext.

        Raises:
            AlreadyFinalized: if the cipher has already been used.
            ValueError: if decrypting and ``tag`` is None.
        """
        result = self._update_helper(data, None, tag)
        assert result is not None
        return result

    def _update_helper(
        self,
        data: bytes,
        out: bytearray | memoryview | None,
        tag: bytes | None = None,
    ) -> bytes | None:
        """Run the one-shot operation and then finalize the cipher.

        Returns the produced bytes when *out* is None. When Cryptodome
        writes into *out*, its call returns None, and so does this method.
        """
        if self._update_func is None:
            raise exc.AlreadyFinalized

        update_func_kwargs = {}
        if self.mode in self._write_into_buffer_unsupported:
            # the mode does not support writing into mutable buffers.
            if out is not None:
                msg = f"writing into buffer unsupported by {self.mode.name}"
                raise NotImplementedError(msg)
        else:
            update_func_kwargs = {"output": out}

        encrypting = self.is_encrypting()
        if not encrypting and tag is None:
            msg = "tag is required for decryption"
            raise ValueError(msg)

        result: bytes | None = None
        if encrypting:
            # Encryption errors (e.g. an ``out`` buffer whose length does not
            # match ``data``) are real failures. Let them propagate rather
            # than finalizing with an unfilled buffer.
            result = self._update_func(data, tag, **update_func_kwargs)
        else:
            # decrypt_and_verify reports a failed MAC check as ValueError.
            # It is suppressed so that finalize(tag) below deals with the bad
            # tag. NOTE(review): this relies on AuthenticationMixin.finalize
            # (not shown here) raising on an invalid tag.
            with contextlib.suppress(ValueError):
                result = self._update_func(data, tag, **update_func_kwargs)
        self.finalize(tag)
        return result

    def update_into(
        self,
        data: bytes,
        out: bytearray | memoryview,
        tag: bytes | None = None,
    ) -> None:
        """Like ``update``, but write the result into the buffer *out*.

        Raises:
            NotImplementedError: if the mode cannot write into a buffer.
        """
        self._update_helper(data, out, tag)
def new(
    encrypting: bool,
    key: bytes,
    mode: Modes,
    iv_or_nonce: bytes,
    *,
    use_hmac: bool = False,
    tag_length: int | None = 16,
    digestmod: None | base.BaseHash = None,
    file: io.BytesIO | io.BufferedReader | None = None,
) -> AEAD | NonAEAD | AEADOneShot | FileCipherWrapper | HMACWrapper:
    """Create a new backend specific AES cipher.

    Args:
        encrypting: True for encryption, False for decryption.
        key: The key for the cipher.
        mode: The mode to use for AES cipher.
        iv_or_nonce:
            The Initialization Vector or Nonce for the cipher. It must not be
            repeated with the same key.

    Keyword Arguments:
        use_hmac:
            Should the cipher use HMAC as authentication or not, if it does
            not support AEAD. (Default: False)
        tag_length:
            Length of HMAC tag. By default, a **16 byte tag** is generated. If
            ``tag_length`` is ``None``, a **non-truncated** tag is generated.
            Length of non-truncated tag depends on the digest size of the
            underlying hash algorithm used by HMAC.
        digestmod:
            The algorithm to use for HMAC. If ``None``, defaults to
            ``sha256``. Specifying this value without setting ``use_hmac`` to
            True has no effect.
        file:
            The source file to read from. If ``file`` is specified and the
            ``mode`` is not an AEAD mode, HMAC is always used.

    Important:
        The following arguments are ignored if the mode is an AEAD mode:

        - ``use_hmac``
        - ``tag_length``
        - ``digestmod``

    Returns:
        AES cipher.

    Raises:
        UnsupportedMode: if the ``mode`` is not supported by the backend.
        NotImplementedError:
            if the ``mode`` does not support encryption/decryption of files.

    Note:
        Any other error that is raised is from the backend itself.
    """
    cipher: base.BaseAEADCipher | base.BaseNonAEADCipher | FileCipherWrapper

    if mode not in supported_modes():
        msg = f"{mode.name} not supported."
        raise exc.UnsupportedMode(msg)

    is_mode_aead = mode in modes.AEAD
    is_file = file is not None
    # Non-AEAD modes have no built-in authentication. HMAC is added when
    # requested, and always for file ciphers.
    use_hmac = (use_hmac or is_file) and not is_mode_aead

    if mode in modes.SPECIAL:
        if is_file:
            msg = (
                f"{mode.name} does not support encryption/decryption of files."
            )
            raise NotImplementedError(msg)
        return AEADOneShot(encrypting, key, mode, iv_or_nonce)

    if is_mode_aead:
        cipher = AEAD(encrypting, key, mode, iv_or_nonce)
    elif use_hmac:
        # _wrap_hmac builds its own NonAEAD from a derived key, so no
        # (relatively expensive) cipher is created from the raw key here.
        cipher = _wrap_hmac(
            encrypting,
            key,
            mode,
            iv_or_nonce,
            digestmod if digestmod is not None else Hash.new("sha256"),
            tag_length,
        )
    else:
        cipher = NonAEAD(encrypting, key, mode, iv_or_nonce)

    if is_file:
        assert isinstance(cipher, base.BaseAEADCipher)
        cipher = FileCipherWrapper(cipher, file)

    return cipher
def supported_modes() -> set[Modes]:
    """Return the modes that this backend's AES implementation handles.

    Returns:
        A fresh set of :any:`Modes` members (the keys of ``SUPPORTED``).
        Changing the returned set does not affect the backend.
    """
    return {*SUPPORTED}
def _wrap_hmac(
    encrypting: bool,
    key: bytes,
    mode: Modes,
    iv_or_nonce: bytes,
    hashalgo: base.BaseHash,
    tag_length: int | None,
) -> HMACWrapper:
    """Build a non-AEAD AES cipher authenticated by an HMAC wrapper.

    ``derive_hkdf_key`` turns *key* into two keys. The first one keys the
    AES cipher and the second keys the HMAC. ``len(key)`` is passed as the
    requested length, and ``iv_or_nonce`` is also passed (presumably as the
    HKDF salt; confirm in ``misc.derive_hkdf_key``).
    """
    cipher_key, mac_key = derive_hkdf_key(key, len(key), hashalgo, iv_or_nonce)
    inner_cipher = NonAEAD(encrypting, cipher_key, mode, iv_or_nonce)
    return HMACWrapper(
        inner_cipher,
        mac_key,
        iv_or_nonce,
        hashalgo,
        tag_length=tag_length,
    )