from __future__ import annotations
import typing
from Cryptodome.PublicKey import RSA
from pyflocker.ciphers import base, exc
from pyflocker.ciphers.backends.asymmetric import OAEP, PSS
from .asymmetric import PROTECTION_SCHEMES, get_padding_algorithm
[docs]
class RSAPrivateKey(base.BaseRSAPrivateKey):
    """RSA private key wrapping a Cryptodome ``RSA.RsaKey`` object."""

    # Encodings supported by this key.
    _ENCODINGS = {
        "PEM": "PEM",
        "DER": "DER",
    }

    # Formats supported by this key ("TraditionalOpenSSL" is an alias of
    # "PKCS1").
    _FORMATS = {
        "PKCS1": "PKCS1",
        "TraditionalOpenSSL": "PKCS1",
        "PKCS8": "PKCS8",
    }

    # The default protection algorithm used for encrypting the private key.
    _DEFAULT_PROTECTION = "scryptAndAES256-CBC"

    def __init__(
        self,
        n: int | None,
        e: int = 65537,
        _key: RSA.RsaKey | None = None,
    ) -> None:
        """Wrap an existing key or generate a new one.

        Args:
            n:
                Key size handed to ``RSA.generate`` as its first argument.
                Ignored when ``_key`` is given.
            e: Public exponent handed to ``RSA.generate``.
            _key:
                An already constructed Cryptodome key. When supplied, no key
                is generated and ``n``/``e`` are not used.

        Raises:
            TypeError: If ``_key`` is not given and ``n`` is not an ``int``.
        """
        if _key is not None:
            self._key = _key
            return

        if not isinstance(n, int):  # pragma: no cover
            msg = "n must be an integer value"
            raise TypeError(msg)
        self._key = RSA.generate(n, e=e)

    @property
    def p(self) -> int:
        """Value of ``p`` taken from the underlying key."""
        return self._key.p

    @property
    def q(self) -> int:
        """Value of ``q`` taken from the underlying key."""
        return self._key.q

    @property
    def d(self) -> int:
        """Value of ``d`` taken from the underlying key."""
        return self._key.d

    @property
    def n(self) -> int:
        """Value of ``n`` taken from the underlying key."""
        return self._key.n

    @property
    def e(self) -> int:
        """Value of ``e`` taken from the underlying key."""
        return self._key.e

    @property
    def key_size(self) -> int:
        """Size of the key in bits, as reported by the underlying key."""
        return self._key.size_in_bits()

    def decryptor(
        self,
        padding: base.BaseAsymmetricPadding | None = None,
    ) -> DecryptorContext:
        """Create a decryption context for this key.

        Args:
            padding: The asymmetric padding to use; ``OAEP()`` when ``None``.

        Returns:
            A ``DecryptorContext`` wrapping the padding algorithm object.
        """
        if padding is None:  # pragma: no cover
            padding = OAEP()
        return DecryptorContext(
            get_padding_algorithm(padding, self._key, padding),
        )

    def signer(
        self,
        padding: base.BaseAsymmetricPadding | None = None,
    ) -> SignerContext:
        """Create a signing context for this key.

        Args:
            padding: The asymmetric padding to use; ``PSS()`` when ``None``.

        Returns:
            A ``SignerContext`` wrapping the padding algorithm object.
        """
        if padding is None:  # pragma: no cover
            padding = PSS()
        return SignerContext(
            get_padding_algorithm(padding, self._key, padding),
        )

    def public_key(self) -> RSAPublicKey:
        """Return the public key derived from this private key."""
        return RSAPublicKey(self._key.publickey())

    def serialize(
        self,
        encoding: str = "PEM",
        format: str = "PKCS8",
        passphrase: bytes | None = None,
        *,
        protection: str | None = None,
    ) -> bytes:
        """Serialize the private key.

        Args:
            encoding: PEM or DER (defaults to PEM).
            format:
                PKCS1 (alias: TraditionalOpenSSL) or PKCS8 (defaults to
                PKCS8).
            passphrase:
                a bytes object to use for encrypting the private key. If
                ``passphrase`` is None, the private key will be exported in the
                clear!

        Keyword Arguments:
            protection:
                The protection scheme to use. Supplying a value for protection
                has meaning only if the ``format`` is PKCS8 and a passphrase
                is given. If ``None`` is provided together with a passphrase,
                ``scryptAndAES256-CBC`` is used as the protection scheme.

        Returns:
            Serialized key as a bytes object.

        Raises:
            ValueError:
                If the encoding or format is incorrect or,
                the protection scheme is unknown or,
                protection value is supplied with PKCS1 format or,
                protection is supplied without a passphrase or,
                the underlying ``export_key`` call rejects the arguments.
        """
        try:
            encoding, format = self._ENCODINGS[encoding], self._FORMATS[format]
        except KeyError as err:
            msg = f"Invalid encoding or format: {err}"
            raise ValueError(msg) from err

        if (
            protection is not None and protection not in PROTECTION_SCHEMES
        ):  # pragma: no cover
            msg = f"invalid protection scheme: {protection!r}"
            raise ValueError(msg)

        if passphrase:
            # Accept any bytes-like object and normalise it to ``bytes``.
            passphrase = memoryview(passphrase).tobytes()

        kwargs: dict[str, typing.Any] = {}
        if encoding == "PEM":
            self._set_pem_args(format, passphrase, protection, kwargs)
        elif encoding == "DER":
            self._set_der_args(format, passphrase, protection, kwargs)

        try:
            key: str | bytes = self._key.export_key(**kwargs)
        except ValueError as err:
            msg = f"Failed to serialize key: {err!s}"
            raise ValueError(msg) from err
        return key if isinstance(key, bytes) else key.encode()

    @classmethod
    def _set_export_args(
        cls,
        encoding: str,
        format: str,
        passphrase: bytes | None,
        protection: str | None,
        kwargs: dict,
    ) -> None:
        """Fill ``kwargs`` for ``export_key`` for the given encoding/format."""
        kwargs["format"] = encoding
        if format == "PKCS8":
            kwargs["pkcs"] = 8
            cls._set_pkcs8_passphrase_args(passphrase, protection, kwargs)
        elif format == "PKCS1":
            kwargs["pkcs"] = 1
            cls._set_pkcs1_passphrase_args(passphrase, protection, kwargs)
        else:
            msg = f"Invalid format for {encoding}: {format!r}"
            raise ValueError(msg)

    @classmethod
    def _set_pem_args(
        cls,
        format: str,
        passphrase: bytes | None,
        protection: str | None,
        kwargs: dict,
    ) -> None:
        """Fill ``kwargs`` for a PEM export of the given ``format``."""
        cls._set_export_args("PEM", format, passphrase, protection, kwargs)

    @classmethod
    def _set_der_args(
        cls,
        format: str,
        passphrase: bytes | None,
        protection: str | None,
        kwargs: dict,
    ) -> None:
        """Fill ``kwargs`` for a DER export of the given ``format``."""
        cls._set_export_args("DER", format, passphrase, protection, kwargs)

    @classmethod
    def _set_pkcs8_passphrase_args(
        cls,
        passphrase: bytes | None,
        protection: str | None,
        kwargs: dict,
    ) -> None:
        """Add the passphrase/protection arguments for a PKCS8 export.

        Raises:
            ValueError: If ``protection`` is given without a passphrase.
        """
        if not passphrase and protection:
            msg = "Using protection without passphrase is invalid"
            raise ValueError(msg)

        kwargs["passphrase"] = passphrase
        # Request a protection scheme only when there is a passphrase to
        # encrypt with. NOTE(review): per the PyCryptodome ``export_key``
        # reference, a non-None ``protection`` on a PEM/PKCS8 export selects
        # the encrypted container even if no passphrase was given, which
        # would mislabel a clear-text key -- confirm against the installed
        # PyCryptodome version.
        if passphrase is not None:
            kwargs["protection"] = protection or cls._DEFAULT_PROTECTION

    @staticmethod
    def _set_pkcs1_passphrase_args(
        passphrase: bytes | None,
        protection: str | None,
        kwargs: dict,
    ) -> None:
        """Add the passphrase argument for a PKCS1 export.

        Raises:
            ValueError: If ``protection`` is given (PKCS8 only).
        """
        if protection is not None:  # pragma: no cover
            msg = "protection is meaningful only for PKCS8"
            raise ValueError(msg)
        if passphrase is not None:
            kwargs["passphrase"] = passphrase

    @staticmethod
    def _validate_pkcs1_args(
        encoding: str,
        protection: str | None,
    ) -> None:
        """Reject ``protection`` and DER encoding for PKCS1.

        Not called by any method of this class; kept for compatibility.

        Raises:
            ValueError: If ``protection`` is given or ``encoding`` is DER.
        """
        if protection is not None:  # pragma: no cover
            msg = "protection is meaningful only for PKCS8"
            raise ValueError(msg)
        if encoding == "DER":
            msg = "cannot use DER with PKCS1 format"
            raise ValueError(msg)

    @classmethod
    def load(
        cls,
        data: bytes,
        passphrase: bytes | None = None,
    ) -> RSAPrivateKey:
        """Load a private key from its serialized form.

        Args:
            data: The serialized key, passed to ``RSA.import_key``.
            passphrase: The passphrase passed to ``RSA.import_key``.

        Returns:
            The loaded RSA private key.

        Raises:
            ValueError:
                If ``RSA.import_key`` fails or the imported key has no
                private part.
        """
        try:
            key = RSA.import_key(data, passphrase)  # type: ignore
            if not key.has_private():
                # Raised inside the ``try`` on purpose: it is re-wrapped
                # below with the common "Failed to load key" prefix.
                msg = "The key is not a private key"
                raise ValueError(msg)
        except ValueError as err:
            msg = f"Failed to load key: {err!s}"
            raise ValueError(msg) from err
        return cls(None, _key=key)
[docs]
class RSAPublicKey(base.BaseRSAPublicKey):
    """RSA public key wrapping a Cryptodome ``RSA.RsaKey`` object."""

    # Encodings supported by this key.
    _ENCODINGS = {
        "PEM": "PEM",
        "DER": "DER",
        "OpenSSH": "OpenSSH",
    }

    # Formats supported by this key.
    _FORMATS = {
        "SubjectPublicKeyInfo": "SubjectPublicKeyInfo",
        "OpenSSH": "OpenSSH",
    }

    def __init__(self, key: RSA.RsaKey) -> None:
        """Wrap the given Cryptodome key object."""
        self._key = key

    @property
    def n(self) -> int:
        """Value of ``n`` taken from the underlying key."""
        return self._key.n

    @property
    def e(self) -> int:
        """Value of ``e`` taken from the underlying key."""
        return self._key.e

    @property
    def key_size(self) -> int:
        """Size of the key in bits, as reported by the underlying key."""
        return self._key.size_in_bits()

    def encryptor(
        self,
        padding: base.BaseAsymmetricPadding | None = None,
    ) -> EncryptorContext:
        """Create an encryption context for this key.

        Args:
            padding: The asymmetric padding to use; ``OAEP()`` when ``None``.

        Returns:
            An ``EncryptorContext`` wrapping the padding algorithm object.
        """
        if padding is None:  # pragma: no cover
            padding = OAEP()
        return EncryptorContext(
            get_padding_algorithm(padding, self._key, padding),
        )

    def verifier(
        self,
        padding: base.BaseAsymmetricPadding | None = None,
    ) -> VerifierContext:
        """Create a signature verification context for this key.

        Args:
            padding: The asymmetric padding to use; ``PSS()`` when ``None``.

        Returns:
            A ``VerifierContext`` wrapping the padding algorithm object.
        """
        if padding is None:  # pragma: no cover
            padding = PSS()
        return VerifierContext(
            get_padding_algorithm(padding, self._key, padding),
        )

    def serialize(
        self,
        encoding: str = "PEM",
        format: str = "SubjectPublicKeyInfo",
    ) -> bytes:
        """Serialize the public key.

        Args:
            encoding: PEM, DER or OpenSSH (defaults to PEM).
            format: The supported formats are:

                - SubjectPublicKeyInfo
                - OpenSSH

                Note:
                    ``format`` argument is not actually used by Cryptodome. It
                    is here to maintain compatibility with pyca/cryptography
                    backend counterpart.

        Returns:
            The serialized public key as bytes object.

        Raises:
            ValueError:
                if the encoding or format is not supported or invalid,
                or OpenSSH encoding is not used with OpenSSH format.
        """
        try:
            encoding, format = self._ENCODINGS[encoding], self._FORMATS[format]
        except KeyError as err:
            msg = f"Invalid encoding or format: {err}"
            raise ValueError(msg) from err

        kwargs: dict[str, typing.Any] = {}
        if encoding == "OpenSSH":
            self._set_openssh_args(format, kwargs)
        elif encoding == "PEM":
            self._set_pem_args(format, kwargs)
        elif encoding == "DER":
            self._set_der_args(format, kwargs)

        try:
            data: str | bytes = self._key.export_key(**kwargs)
        except ValueError as err:
            msg = f"Failed to serialize key: {err!s}"
            raise ValueError(msg) from err
        return data if isinstance(data, bytes) else data.encode("utf-8")

    @staticmethod
    def _set_openssh_args(format: str, kwargs: dict) -> None:
        """Set the export format for OpenSSH; only the OpenSSH format fits."""
        if format == "OpenSSH":
            kwargs["format"] = "OpenSSH"
            return
        msg = f"Invalid format for OpenSSH: {format!r}"
        raise ValueError(msg)

    @staticmethod
    def _set_pem_args(format: str, kwargs: dict) -> None:
        """Set the export format for PEM; requires SubjectPublicKeyInfo."""
        if format == "SubjectPublicKeyInfo":
            kwargs["format"] = "PEM"
            return
        msg = f"Invalid format for PEM: {format!r}"
        raise ValueError(msg)

    @staticmethod
    def _set_der_args(format: str, kwargs: dict) -> None:
        """Set the export format for DER; requires SubjectPublicKeyInfo."""
        if format == "SubjectPublicKeyInfo":
            kwargs["format"] = "DER"
            return
        msg = f"Invalid format for DER: {format!r}"
        raise ValueError(msg)

    @classmethod
    def load(cls, data: bytes) -> RSAPublicKey:
        """Load a public key from its serialized form.

        Args:
            data: The serialized key, passed to ``RSA.import_key``.

        Returns:
            The loaded RSA public key.

        Raises:
            ValueError:
                If ``RSA.import_key`` fails or the imported key contains a
                private part.
        """
        try:
            key = RSA.import_key(data)
            if key.has_private():
                # Raised inside the ``try`` on purpose: it is re-wrapped
                # below with the common "Failed to load key" prefix.
                msg = "The key is not a public key"
                raise ValueError(msg)
        except ValueError as err:
            msg = f"Failed to load key: {err!s}"
            raise ValueError(msg) from err
        return cls(key)
[docs]
class EncryptorContext(base.BaseEncryptorContext):
    """Encryption context that delegates to a wrapped cipher object."""

    def __init__(self, ctx: typing.Any) -> None:
        """Store ``ctx``, an object providing an ``encrypt`` method."""
        self._ctx = ctx

    def encrypt(self, plaintext: bytes) -> bytes:
        """Encrypt ``plaintext`` with the wrapped object and return the result."""
        ciphertext = self._ctx.encrypt(plaintext)
        return ciphertext
[docs]
class DecryptorContext(base.BaseDecryptorContext):
    """Decryption context that delegates to a wrapped cipher object."""

    def __init__(self, ctx: typing.Any) -> None:
        """Store ``ctx``, an object providing a ``decrypt`` method."""
        self._ctx = ctx

    def decrypt(self, plaintext: bytes) -> bytes:
        """Decrypt the given data with the wrapped object.

        Args:
            plaintext: The data handed to the wrapped ``decrypt`` method.

        Returns:
            Whatever the wrapped ``decrypt`` method returns.

        Raises:
            DecryptionError: If the wrapped object raises ``ValueError``.
        """
        try:
            recovered = self._ctx.decrypt(plaintext)
        except ValueError as err:
            raise exc.DecryptionError from err
        return recovered
[docs]
class SignerContext(base.BaseSignerContext):
    """Signing context that delegates to a wrapped signer object."""

    def __init__(self, ctx: typing.Any) -> None:
        """Store ``ctx``, an object providing a ``sign`` method."""
        self._ctx = ctx

    def sign(self, msghash: base.BaseHash) -> bytes:
        """Sign ``msghash`` with the wrapped object and return the signature."""
        signature = self._ctx.sign(msghash)
        return signature
[docs]
class VerifierContext(base.BaseVerifierContext):
    """Verification context that delegates to a wrapped verifier object."""

    def __init__(self, ctx: typing.Any) -> None:
        """Store ``ctx``, an object providing a ``verify`` method."""
        self._ctx = ctx

    def verify(self, msghash: base.BaseHash, signature: bytes) -> None:
        """Check ``signature`` against ``msghash`` using the wrapped object.

        Raises:
            SignatureError: If the wrapped object raises ``ValueError``.
        """
        checker = self._ctx
        try:
            checker.verify(msghash, signature)
        except ValueError as err:
            raise exc.SignatureError from err
[docs]
def generate(bits: int, e: int = 65537) -> RSAPrivateKey:
    """Create a new RSA private key.

    The key is built with modulus size ``bits`` and public exponent ``e``
    (default 65537). Recommended size of ``bits`` > 1024.

    Args:
        bits: The bit length of the RSA key.
        e: The public exponent value. Default is 65537.

    Returns:
        The RSA private key.
    """
    private_key = RSAPrivateKey(bits, e)
    return private_key
[docs]
def load_public_key(data: bytes) -> RSAPublicKey:
    """Deserialize an RSA public key.

    Thin wrapper around ``RSAPublicKey.load``.

    Args:
        data: The public key (a bytes-like object) to deserialize.

    Returns:
        The RSA public key.
    """
    public_key = RSAPublicKey.load(data)
    return public_key
[docs]
def load_private_key(
    data: bytes,
    passphrase: bytes | None = None,
) -> RSAPrivateKey:
    """Deserialize an RSA private key.

    Thin wrapper around ``RSAPrivateKey.load``. If the private key was not
    encrypted during serialization, ``passphrase`` must be ``None``;
    otherwise it must be a ``bytes`` object.

    Args:
        data: The private key (a bytes-like object) to deserialize.
        passphrase:
            The passphrase that was used to encrypt the private key. ``None``
            if the private key is not encrypted.

    Returns:
        The RSA private key.
    """
    private_key = RSAPrivateKey.load(data, passphrase)
    return private_key