Source code for pyflocker.ciphers.backends.cryptography_.AES

"""Implementation of AES cipher."""

from __future__ import annotations

import hmac
import struct
import typing
from types import MappingProxyType

import cryptography.exceptions as bkx
from cryptography.hazmat.backends import default_backend as defb
from cryptography.hazmat.primitives import cmac
from cryptography.hazmat.primitives.ciphers import (
    Cipher as CrCipher,
    aead,
    algorithms as algo,
    modes,
)

from pyflocker.ciphers import (
    base,
    exc,
    modes as modes_,
)
from pyflocker.ciphers.backends.symmetric import (
    FileCipherWrapper,
    HMACWrapper,
    _DecryptionCtx,
    _EncryptionCtx,
)
from pyflocker.ciphers.modes import Modes

from . import Hash
from .misc import derive_hkdf_key
from .symmetric import AEADCipherTemplate, NonAEADCipherTemplate

if typing.TYPE_CHECKING:
    import io

# Maps each AES mode offered by this backend to the pyca/cryptography class
# that implements it. ``supported_modes`` exposes the keys, and
# ``_aes_cipher`` looks up the value to build the backend cipher object.
# Wrapped in a read-only mapping proxy so the table cannot be mutated.
SUPPORTED = MappingProxyType(
    {
        Modes.MODE_GCM: modes.GCM,
        # EAX has no backend class; ``_aes_cipher`` builds ``_EAX`` instead.
        Modes.MODE_EAX: None,  # not defined by backend
        Modes.MODE_CTR: modes.CTR,
        Modes.MODE_CFB8: modes.CFB8,
        Modes.MODE_CFB: modes.CFB,
        Modes.MODE_OFB: modes.OFB,
        # AESCCM is constructed from the key alone (see ``_aes_cipher``).
        Modes.MODE_CCM: aead.AESCCM,
    }
)

# Only needed to build SUPPORTED; keep it out of the module namespace.
del MappingProxyType


class AEAD(AEADCipherTemplate):
    """AES in a streaming authenticated mode, driven by a backend context.

    The backend object built by ``_aes_cipher`` hands out a stateful
    encryption or decryption context, which is stored as ``_ctx`` for
    :class:`AEADCipherTemplate` to drive.
    """

    def __init__(
        self,
        encrypting: bool,
        key: bytes,
        mode: Modes,
        nonce: bytes,
    ) -> None:
        self._mode = mode
        self._encrypting = encrypting
        self._tag = None
        self._updated = False

        backend_cipher = _aes_cipher(key, mode, nonce)
        # The backend already provides a context; pick it by direction.
        self._ctx = (
            backend_cipher.encryptor()
            if encrypting
            else backend_cipher.decryptor()
        )

    @property
    def mode(self) -> Modes:
        """The AES mode."""
        return self._mode
class NonAEAD(NonAEADCipherTemplate):
    """AES in an unauthenticated mode, driven by a backend context.

    The backend cipher built by ``_aes_cipher`` supplies the encryption or
    decryption context stored as ``_ctx`` for :class:`NonAEADCipherTemplate`.
    """

    def __init__(
        self,
        encrypting: bool,
        key: bytes,
        mode: Modes,
        nonce: bytes,
    ) -> None:
        self._mode = mode
        self._encrypting = encrypting

        backend_cipher = _aes_cipher(key, mode, nonce)
        # The backend already provides a context; pick it by direction.
        self._ctx = (
            backend_cipher.encryptor()
            if encrypting
            else backend_cipher.decryptor()
        )

    @property
    def mode(self) -> Modes:
        """The AES mode."""
        return self._mode
class AEADOneShot(base.BaseAEADOneShotCipher):
    """AES cipher for AEAD modes whose backend handles whole messages only.

    The backend object returned by ``_aes_cipher`` (``aead.AESCCM`` for CCM)
    offers ``encrypt``/``decrypt`` calls that each process one complete
    message together with its associated data. :meth:`update` can therefore
    be called at most once, and it finalizes the cipher itself.
    """

    def __init__(
        self,
        encrypting: bool,
        key: bytes,
        mode: Modes,
        nonce: bytes,
    ) -> None:
        cipher = _aes_cipher(key, mode, nonce)
        self._mode = mode
        self._encrypting = encrypting
        self._aad = b""
        self._tag: bytes | None = None
        self._nonce = nonce
        # Reset to None on finalization; doubles as the "finalized" flag.
        self._update_func = cipher.encrypt if encrypting else cipher.decrypt
        self._raise_on_tag_err = False
        self._tag_length = 16
        # True once a message has gone through the backend. finalize() uses
        # it to detect that update() was skipped (AAD-only message).
        self._processed = False

    @property
    def mode(self) -> Modes:
        """The AES mode."""
        return self._mode

    def authenticate(self, data: bytes) -> None:
        """Append ``data`` to the associated data covered by the tag.

        Raises:
            AlreadyFinalized: if the cipher has already been finalized.
        """
        if self._update_func is None:
            raise exc.AlreadyFinalized
        self._aad += data

    def is_encrypting(self) -> bool:
        """Return True when encrypting, False when decrypting."""
        return self._encrypting

    def update(
        self,
        data: bytes,
        tag: bytes | None = None,
    ) -> bytes:
        """Encrypt or decrypt the complete message, then finalize.

        Args:
            data: The plaintext (encryption) or ciphertext (decryption).
            tag: The authentication tag. Required for decryption.

        Returns:
            The ciphertext without the tag, or the recovered plaintext.

        Raises:
            AlreadyFinalized: if the cipher has already been finalized.
            ValueError: if decrypting and ``tag`` is None.
            DecryptionError: if the tag does not authenticate the message.
        """
        if self._update_func is None:
            raise exc.AlreadyFinalized
        if self.is_encrypting():
            # bytes() lets memoryview/bytearray input through unchanged.
            ctxt_tag = self._update_func(self._nonce, bytes(data), self._aad)
            self._tag = ctxt_tag[-self._tag_length :]
            self._processed = True
            self.finalize(tag)
            return ctxt_tag[: -self._tag_length]
        if tag is None:
            msg = "tag is required for decryption."
            raise ValueError(msg)
        # Mark before calling so finalize() (in the finally clause) does not
        # try to process an empty message on top of this one.
        self._processed = True
        try:
            # The backend expects the tag appended to the ciphertext.
            plaintext = self._update_func(
                self._nonce, bytes(data) + bytes(tag), self._aad
            )
        except bkx.InvalidTag:
            self._raise_on_tag_err = True
        finally:
            # Always finalize; raises DecryptionError on a tag mismatch, so
            # unauthenticated plaintext is never returned.
            self.finalize(tag)
        return plaintext

    def update_into(
        self,
        data: bytes,
        out: bytearray | memoryview,
        tag: bytes | None = None,
    ) -> None:
        """Not supported by one-shot ciphers.

        Raises:
            NotImplementedError: always.
        """
        del tag, out, data
        raise NotImplementedError

    def finalize(self, tag: bytes | None = None) -> None:
        """Finalize the cipher.

        If :meth:`update` was never called, an empty message is processed so
        that the tag still covers the associated data. When encrypting, this
        computes the tag. When decrypting, it verifies ``tag``.

        Args:
            tag: The authentication tag. Required for decryption.

        Raises:
            AlreadyFinalized: if the cipher has already been finalized.
            ValueError: if decrypting and ``tag`` is None.
            DecryptionError: if the tag does not authenticate the message.
        """
        if self._update_func is None:
            raise exc.AlreadyFinalized
        if not self.is_encrypting() and tag is None:
            msg = "tag is required for decryption."
            raise ValueError(msg)
        if not self._processed:
            self._process_empty_message(tag)
        self._update_func = None
        if self._raise_on_tag_err:
            raise exc.DecryptionError

    def calculate_tag(self) -> bytes | None:
        """Return the tag produced by encryption (None when decrypting).

        Raises:
            NotFinalized: if the cipher has not been finalized yet.
        """
        if self._update_func is not None:
            raise exc.NotFinalized
        return self._tag

    def _process_empty_message(self, tag: bytes | None) -> None:
        """Run a zero-length message through the backend (AAD-only case)."""
        assert self._update_func is not None
        if self.is_encrypting():
            # With empty plaintext the backend output is just the tag.
            ctxt_tag = self._update_func(self._nonce, b"", self._aad)
            self._tag = ctxt_tag[-self._tag_length :]
        else:
            assert tag is not None
            try:
                self._update_func(self._nonce, bytes(tag), self._aad)
            except bkx.InvalidTag:
                self._raise_on_tag_err = True
        self._processed = True
class _AuthWrapper:
    """Wrapper class for objects that do not support memoryview objects.

    ``update`` converts its argument to ``bytes`` before forwarding it. Every
    other attribute lookup is delegated to the wrapped object.
    """

    __slots__ = ("_auth",)

    def __init__(self, auth: typing.Any) -> None:
        # In this file the wrapped object is always a ``cmac.CMAC`` instance.
        self._auth = auth

    def update(self, data: bytes) -> None:
        """Forward ``bytes(data)`` so memoryview/bytearray input is accepted."""
        self._auth.update(bytes(data))

    def __getattr__(self, name: str) -> typing.Any:
        # Only reached for names not found on the wrapper itself
        # (e.g. ``finalize``); they are looked up on the wrapped object.
        return getattr(self._auth, name)


class _EAX:
    """AES-EAX adapter for pyca/cryptography.

    pyca/cryptography has no EAX mode, so it is assembled here from three
    AES-CMAC instances plus AES-CTR:

    * ``_omac[0]`` authenticates the nonce. Its output is both the initial
      CTR counter block and the first term of the tag.
    * ``_omac[1]`` authenticates the associated data (through ``_auth``).
    * ``_omac[2]`` is handed to the encryption/decryption context, which
      presumably feeds it the ciphertext (see ``_EncryptionCtx`` /
      ``_DecryptionCtx`` -- TODO confirm).

    The tag is the XOR of the three CMAC outputs, truncated to ``mac_len``.
    The object mimics a pyca ``Cipher``: ``encryptor()``/``decryptor()``
    choose the direction and return ``self``.
    """

    __slots__ = (
        "_mac_len",
        "_omac",
        "_auth",
        "_omac_cache",
        "_cipher",
        "_updated",
        "__ctx",
        "__tag",
    )

    def __init__(self, key: bytes, nonce: bytes, mac_len: int = 16) -> None:
        self._mac_len = mac_len
        self._omac = [cmac.CMAC(algo.AES(key), defb()) for _ in range(3)]
        # Prefix CMAC number i with one block of zero bytes ending in byte i
        # (block_size is in bits, hence // 8). This keeps the three CMACs
        # distinct even though they share the key.
        for i in range(3):
            self._omac[i].update(
                bytes(1) * (algo.AES.block_size // 8 - 1)
                + struct.pack("B", i)  # noqa: W503
            )
        self._omac[0].update(nonce)
        self._auth = _AuthWrapper(self._omac[1])

        # create a cache since cryptography allows us to calculate tag
        # only once... why...
        # Entry 0 (the nonce CMAC) is needed right away as the CTR counter
        # block, and again later in finalize().
        self._omac_cache: list[bytes] = []
        self._omac_cache.append(self._omac[0].finalize())

        self._cipher = CrCipher(
            algo.AES(key),
            modes.CTR(self._omac_cache[0]),
            defb(),
        )
        # Set by encryptor()/decryptor(); reset to None by finalize().
        self.__ctx = None
        # Becomes True after the first update; AAD is rejected afterwards.
        self._updated = False
        self.__tag = None

    @property
    def _ctx(self) -> typing.Any:  # pragma: no cover
        """The Cipher context used by the backend.

        Maintains compatibility across pyca/cryptography style cipher instances.
        Returns the backend context held by the active wrapper context, or
        None when no context is active.
        """
        if self.__ctx:
            return self.__ctx._ctx
        return None

    def authenticate_additional_data(self, data: bytes) -> None:
        """Feed associated data into the header CMAC (``_omac[1]``).

        Raises:
            AlreadyFinalized: if no context is active.
            ValueError: if called after data has been encrypted/decrypted.
        """
        if self.__ctx is None:  # pragma: no cover
            raise bkx.AlreadyFinalized
        if self._updated:
            raise ValueError  # pragma: no cover
        self._auth.update(data)

    def encryptor(self) -> _EAX:
        """Start encryption with a CTR encryptor and the ciphertext CMAC."""
        # NOTE(review): 15 is presumably the extra out-buffer room that
        # ``update_into`` needs (AES block size - 1) -- TODO confirm.
        self.__ctx = _EncryptionCtx(
            self._cipher.encryptor(),  # type: ignore
            _AuthWrapper(self._omac[2]),
            15,
        )
        return self

    def decryptor(self) -> _EAX:
        """Start decryption with a CTR decryptor and the ciphertext CMAC."""
        self.__ctx = _DecryptionCtx(
            self._cipher.decryptor(),  # type: ignore
            _AuthWrapper(self._omac[2]),
        )
        return self

    def update(self, data: bytes) -> bytes:
        """Process ``data`` through the active context and return its output.

        Raises:
            AlreadyFinalized: if no context is active.
        """
        if self.__ctx is None:  # pragma: no cover
            raise bkx.AlreadyFinalized
        self._updated = True
        return self.__ctx.update(data)

    def update_into(
        self,
        data: bytes,
        out: bytearray | memoryview,
    ) -> None:
        """Process ``data`` through the active context, writing into ``out``.

        Raises:
            AlreadyFinalized: if no context is active.
        """
        if self.__ctx is None:  # pragma: no cover
            raise bkx.AlreadyFinalized
        self._updated = True
        self.__ctx.update_into(data, out)

    def finalize(self) -> None:
        """Finalizes the cipher.

        Computes the tag as the XOR of the three CMAC outputs, truncates it to
        ``mac_len`` bytes and drops the active context.

        NOTE(review): the underlying CTR context is never finalized here.
        That looks harmless for a stream mode such as CTR; presumably it is
        intentional -- confirm.

        Raises:
            AlreadyFinalized: if no context is active.
        """
        if self.__ctx is None:  # pragma: no cover
            raise bkx.AlreadyFinalized
        tag = bytes(typing.cast("int", algo.AES.block_size) // 8)
        for i in range(3):
            # A CMAC can be finalized only once, so reuse any cached output
            # (index 0 was cached in __init__).
            if i >= len(self._omac_cache):
                self._omac_cache.append(self._omac[i].finalize())
            tag = strxor(tag, self._omac_cache[i])
        self.__tag, self.__ctx = tag[: self._mac_len], None

    def finalize_with_tag(self, tag: bytes) -> None:
        """Finalize, then compare ``tag`` with the computed tag.

        Raises:
            InvalidTag: if the tags differ (compared in constant time).
        """
        self.finalize()
        assert self.__tag is not None
        if not hmac.compare_digest(tag, self.__tag):
            raise bkx.InvalidTag  # pragma: no cover

    @property
    def tag(self) -> bytes | None:
        """The computed tag, or None if no tag has been computed yet.

        Raises:
            NotYetFinalized: while a context is still active.
        """
        if self.__ctx is not None:  # pragma: no cover
            raise bkx.NotYetFinalized
        return self.__tag
def strxor(x: bytes, y: bytes) -> bytes:
    """XOR two byte strings.

    Bytes are paired position by position. The result is as long as the
    shorter input, and any extra trailing bytes are ignored.
    """
    return bytes([left ^ right for left, right in zip(x, y)])
def new(
    encrypting: bool,
    key: bytes,
    mode: Modes,
    iv_or_nonce: bytes,
    *,
    use_hmac: bool = False,
    tag_length: int | None = 16,
    digestmod: None | base.BaseHash = None,
    file: io.BufferedReader | None = None,
) -> AEAD | NonAEAD | AEADOneShot | FileCipherWrapper | HMACWrapper:
    """Create a new backend specific AES cipher.

    Args:
        encrypting: True is encryption and False is decryption.
        key: The key for the cipher.
        mode: The mode to use for AES cipher.
        iv_or_nonce:
            The Initialization Vector or Nonce for the cipher. It must not be
            repeated with the same key.

    Keyword Arguments:
        use_hmac:
            Should the cipher use HMAC as authentication or not, if it does
            not support AEAD. (Default: False)
        tag_length:
            Length of HMAC tag. By default, a **16 byte tag** is generated. If
            ``tag_length`` is ``None``, a **non-truncated** tag is generated.
            Length of non-truncated tag depends on the digest size of the
            underlying hash algorithm used by HMAC.
        digestmod:
            The algorithm to use for HMAC. If ``None``, Defaults to ``sha256``.
            Specifying this value without setting ``use_hmac`` to True has no
            effect.
        file:
            The source file to read from. If `file` is specified and the
            `mode` is not an AEAD mode, HMAC is always used.

    Important:
        The following arguments are ignored if the mode is an AEAD mode:

        - ``use_hmac``
        - ``tag_length``
        - ``digestmod``

    Returns:
        AES cipher.

    Raises:
        UnsupportedMode: if ``mode`` is not supported by this backend.
        NotImplementedError:
            if the ``mode`` does not support encryption/decryption of files.

    Note:
        Any other error that is raised is from the backend itself.
    """
    cipher: base.BaseAEADCipher | base.BaseNonAEADCipher | FileCipherWrapper

    if mode not in supported_modes():
        msg = f"{mode.name} not supported."
        raise exc.UnsupportedMode(msg)

    is_mode_aead = mode in modes_.AEAD
    is_file = file is not None
    # HMAC only applies to non-AEAD modes: on request, and always for files.
    use_hmac = (use_hmac or is_file) and not is_mode_aead

    if mode in modes_.SPECIAL:
        if is_file:
            msg = (
                f"{mode.name} does not support encryption/decryption of files."
            )
            raise NotImplementedError(msg)
        return AEADOneShot(encrypting, key, mode, iv_or_nonce)

    if is_mode_aead:
        cipher = AEAD(encrypting, key, mode, iv_or_nonce)
    elif use_hmac:
        # Build only the HMAC-wrapped cipher. It derives its own cipher and
        # HMAC subkeys, so a plain NonAEAD keyed with the master key would
        # just be thrown away.
        cipher = _wrap_hmac(
            encrypting,
            key,
            mode,
            iv_or_nonce,
            digestmod if digestmod is not None else Hash.new("sha256"),
            tag_length,
        )
    else:
        cipher = NonAEAD(encrypting, key, mode, iv_or_nonce)

    if file is not None:
        # Files always get an authenticated cipher (AEAD or HMAC-wrapped).
        assert isinstance(cipher, base.BaseAEADCipher)
        cipher = FileCipherWrapper(cipher, file, offset=15)
    return cipher
def supported_modes() -> set[Modes]:
    """Lists all modes supported by AES cipher of this backend.

    Returns:
        set of :any:`Modes` object supported by backend. A new set is built
        on every call, so mutating it does not affect the backend.
    """
    return set(SUPPORTED.keys())
def _aes_cipher(key: bytes, mode: Modes, nonce_or_iv: bytes) -> typing.Any:
    """Build the backend object that implements AES in ``mode``.

    EAX gets the local ``_EAX`` adapter, CCM gets a one-shot ``AESCCM``
    keyed object, and every other mode gets a pyca ``Cipher``.

    Raises:
        ValueError: for CCM when the nonce is not 7 to 13 bytes long.
    """
    if mode == Modes.MODE_EAX:
        return _EAX(key, nonce_or_iv)

    mode_factory = SUPPORTED[mode]
    assert mode_factory is not None

    if mode != Modes.MODE_CCM:
        assert not issubclass(mode_factory, aead.AESCCM)
        return CrCipher(algo.AES(key), mode_factory(nonce_or_iv))

    # AESCCM takes only the key; the nonce is supplied per message later,
    # so its length is validated here.
    nonce_len = len(nonce_or_iv)
    if nonce_len < 7 or nonce_len > 13:
        msg = "Length of nonce must be between 7 and 13 bytes"
        raise ValueError(msg)
    return mode_factory(key)


def _wrap_hmac(
    encrypting: bool,
    key: bytes,
    mode: Modes,
    iv_or_nonce: bytes,
    digestmod: base.BaseHash,
    tag_length: int | None,
) -> HMACWrapper:
    """Wrap a non-AEAD AES cipher with HMAC, using HKDF-derived subkeys.

    The master key is split into a cipher key and an HMAC key of the same
    length, with ``iv_or_nonce`` passed to the derivation.
    """
    cipher_key, mac_key = derive_hkdf_key(key, len(key), digestmod, iv_or_nonce)
    inner_cipher = NonAEAD(encrypting, cipher_key, mode, iv_or_nonce)
    return HMACWrapper(
        inner_cipher,
        mac_key,
        iv_or_nonce,
        digestmod,
        tag_length=tag_length,
        offset=15,
    )