from __future__ import annotations
import typing
from functools import partial
import cryptography.exceptions as bkx
from cryptography.hazmat.primitives import serialization as serial
from cryptography.hazmat.primitives.asymmetric import rsa, utils
from cryptography.hazmat.primitives.serialization import (
Encoding,
PrivateFormat,
PublicFormat,
)
from pyflocker.ciphers import base, exc
from pyflocker.ciphers.backends.asymmetric import OAEP, PSS
from . import Hash
from .asymmetric import get_padding_algorithm
if typing.TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric import padding as _padding
[docs]
class RSAPrivateKey(base.BaseRSAPrivateKey):
# Encodings supported by this key.
_ENCODINGS = {
"PEM": Encoding.PEM,
"DER": Encoding.DER,
}
# Formats supported by this key.
_FORMATS = {
"TraditionalOpenSSL": PrivateFormat.TraditionalOpenSSL,
"PKCS1": PrivateFormat.TraditionalOpenSSL,
"OpenSSH": PrivateFormat.OpenSSH,
"PKCS8": PrivateFormat.PKCS8,
}
# Key loaders indexed by the key format.
_LOADERS: dict[
bytes,
typing.Callable[[bytes, bytes | None], typing.Any],
] = {
b"-----BEGIN OPENSSH PRIVATE KEY": serial.load_ssh_private_key,
b"-----": serial.load_pem_private_key,
b"0": serial.load_der_private_key,
}
def __init__(
self,
n: int | None,
e: int = 65537,
_key: rsa.RSAPrivateKey | None = None,
) -> None:
if _key is not None:
self._key = _key
else:
if not isinstance(n, int): # pragma: no cover
msg = "n must be an integer value"
raise TypeError(msg)
self._key = rsa.generate_private_key(e, n)
# numbers
priv_nos = self._key.private_numbers()
self._p = priv_nos.p
self._q = priv_nos.q
self._d = priv_nos.d
pub_nos = priv_nos.public_numbers
self._e = pub_nos.e
self._n = pub_nos.n
@property
def p(self) -> int:
return self._p
@property
def q(self) -> int:
return self._q
@property
def d(self) -> int:
return self._d
@property
def e(self) -> int:
return self._e
@property
def n(self) -> int:
return self._n
@property
def key_size(self) -> int:
return self._key.key_size
[docs]
def public_key(self) -> RSAPublicKey:
return RSAPublicKey(self._key.public_key())
[docs]
def decryptor(
self,
padding: base.BaseAsymmetricPadding | None = None,
) -> DecryptorContext:
if padding is None: # pragma: no cover
padding = OAEP()
return DecryptorContext(
self._key,
get_padding_algorithm(padding, padding),
)
[docs]
def signer(
self,
padding: base.BaseAsymmetricPadding | None = None,
) -> SignerContext:
if padding is None: # pragma: no cover
padding = PSS()
return SignerContext(
self._key,
get_padding_algorithm(padding, padding),
)
[docs]
def serialize(
self,
encoding: str = "PEM",
format: str = "PKCS8",
passphrase: bytes | None = None,
) -> bytes:
"""Serialize the private key.
Args:
encoding: PEM or DER (defaults to PEM).
format: The formats can be:
- PKCS8 (default)
- TraditionalOpenSSL
- OpenSSH (available from pyca/cryptography version >=3.X)
- PKCS1 (alias to TraditionalOpenSSL for Cryptodome compat)
passphrase:
A bytes-like object to protect the private key. If
``passphrase`` is None, the private key will be exported in the
clear!
Returns:
The private key as a bytes object.
Raises:
ValueError: if the format or encoding is invalid or not supported.
"""
try:
encd = self._ENCODINGS[encoding]
fmt = self._FORMATS[format]
except KeyError as e:
msg = f"The encoding or format is invalid: {e.args[0]!r}"
raise ValueError(msg) from e
protection: serial.KeySerializationEncryption
if passphrase is None:
protection = serial.NoEncryption()
else:
protection = serial.BestAvailableEncryption(
memoryview(passphrase).tobytes()
)
try:
return self._key.private_bytes(encd, fmt, protection)
except ValueError as e:
msg = f"Failed to serialize key: {e!s}"
raise ValueError(msg) from e
[docs]
@classmethod
def load(
cls,
data: bytes,
passphrase: bytes | None = None,
) -> RSAPrivateKey:
loader = cls._get_loader(data)
if passphrase is not None:
passphrase = memoryview(passphrase).tobytes()
try:
key = loader(memoryview(data), passphrase)
if not isinstance(key, rsa.RSAPrivateKey):
msg = "Key is not an RSA Private key"
raise ValueError(msg)
except (ValueError, TypeError) as e:
msg = f"Failed to load key: {e!s}"
raise ValueError(msg) from e
return cls(None, _key=key)
@classmethod
def _get_loader(
cls,
data: bytes,
) -> typing.Callable[[bytes, bytes | None], rsa.RSAPrivateKey]:
"""
Returns a loader function depending on the initial bytes of the key.
"""
try:
return cls._LOADERS[next(filter(data.startswith, cls._LOADERS))]
except StopIteration:
msg = "Invalid format"
raise ValueError(msg) from None
[docs]
class RSAPublicKey(base.BaseRSAPublicKey):
# Encodings supported by this key.
_ENCODINGS = {
"PEM": Encoding.PEM,
"DER": Encoding.DER,
"OpenSSH": Encoding.OpenSSH,
}
# Formats supported by this key.
_FORMATS = {
"OpenSSH": PublicFormat.OpenSSH,
"PKCS1": PublicFormat.PKCS1,
"SubjectPublicKeyInfo": PublicFormat.SubjectPublicKeyInfo,
}
# Key loaders indexed by the key format.
_LOADERS = {
b"ssh-rsa ": serial.load_ssh_public_key,
b"-----": serial.load_pem_public_key,
b"0": serial.load_der_public_key,
}
def __init__(self, key: rsa.RSAPublicKey) -> None:
if not isinstance(key, rsa.RSAPublicKey): # pragma: no cover
msg = "The key is not an RSA public key."
raise ValueError(msg)
self._key = key
# numbers
pub_nos = self._key.public_numbers()
self._e = pub_nos.e
self._n = pub_nos.n
@property
def e(self) -> int:
return self._e
@property
def n(self) -> int:
return self._n
@property
def key_size(self) -> int:
return self._key.key_size
[docs]
def encryptor(
self,
padding: base.BaseAsymmetricPadding | None = None,
) -> EncryptorContext:
if padding is None: # pragma: no cover
padding = OAEP()
return EncryptorContext(
self._key,
get_padding_algorithm(padding, padding),
)
[docs]
def verifier(
self,
padding: base.BaseAsymmetricPadding | None = None,
) -> VerifierContext:
if padding is None: # pragma: no cover
padding = PSS()
return VerifierContext(
self._key,
get_padding_algorithm(padding, padding),
)
[docs]
def serialize(
self,
encoding: str = "PEM",
format: str = "SubjectPublicKeyInfo",
) -> bytes:
"""Serialize the public key.
Args:
encoding: PEM, DER or OpenSSH (defaults to PEM).
format: The supported formats are:
- SubjectPublicKeyInfo (default)
- PKCS1
- OpenSSH
Returns:
Serialized public key as bytes object.
Raises:
ValueError: if the encoding or format is incorrect or unsupported.
"""
try:
encd = self._ENCODINGS[encoding]
fmt = self._FORMATS[format]
except KeyError as e:
msg = f"Invalid encoding or format: {e.args[0]!r}"
raise ValueError(msg) from e
return self._key.public_bytes(encd, fmt)
[docs]
@classmethod
def load(cls, data: bytes) -> RSAPublicKey:
loader = cls._get_loader(data)
try:
key = loader(memoryview(data))
if not isinstance(key, rsa.RSAPublicKey):
msg = "Key is not an RSA public key."
raise ValueError(msg)
except ValueError as e:
msg = f"Failed to load key: {e!s}"
raise ValueError(msg) from e
assert isinstance(key, rsa.RSAPublicKey)
return cls(key)
@classmethod
def _get_loader(cls, data: bytes) -> typing.Callable:
"""
Returns a loader function depending on the initial bytes of the key.
"""
try:
return cls._LOADERS[next(filter(data.startswith, cls._LOADERS))]
except StopIteration:
msg = "Invalid format."
raise ValueError(msg) from None
[docs]
class EncryptorContext(base.BaseEncryptorContext):
def __init__(
self,
key: rsa.RSAPublicKey,
padding: _padding.AsymmetricPadding,
) -> None:
self._encrypt_func = partial(key.encrypt, padding=padding)
[docs]
def encrypt(self, plaintext: bytes) -> bytes:
return self._encrypt_func(plaintext)
[docs]
class DecryptorContext(base.BaseDecryptorContext):
def __init__(
self,
key: rsa.RSAPrivateKey,
padding: _padding.AsymmetricPadding,
) -> None:
self._decrypt_func = partial(key.decrypt, padding=padding)
[docs]
def decrypt(self, ciphertext: bytes) -> bytes:
try:
return self._decrypt_func(ciphertext)
except ValueError as e:
raise exc.DecryptionError from e
[docs]
class SignerContext(base.BaseSignerContext):
def __init__(
self,
key: rsa.RSAPrivateKey,
padding: _padding.AsymmetricPadding,
) -> None:
self._sign_func = partial(key.sign, padding=padding)
[docs]
def sign(self, msghash: base.BaseHash) -> bytes:
return self._sign_func(
data=msghash.digest(),
algorithm=utils.Prehashed(Hash._get_hash_algorithm(msghash)),
)
[docs]
class VerifierContext(base.BaseVerifierContext):
def __init__(
self, key: rsa.RSAPublicKey, padding: _padding.AsymmetricPadding
) -> None:
self._verify_func = partial(key.verify, padding=padding)
[docs]
def verify(self, msghash: base.BaseHash, signature: bytes) -> None:
try:
return self._verify_func(
signature=signature,
data=msghash.digest(),
algorithm=utils.Prehashed(Hash._get_hash_algorithm(msghash)),
)
except bkx.InvalidSignature as e:
raise exc.SignatureError from e
[docs]
def generate(bits: int, e: int = 65537) -> RSAPrivateKey:
"""
Generate a private key with given key modulus ``bits`` and public exponent
``e`` (default 65537). Recommended size of ``bits`` > 1024.
Args:
bits: The bit length of the RSA key.
e: The public exponent value. Default is 65537.
Returns:
RSAPrivateKey: The RSA private key.
"""
return RSAPrivateKey(bits, e)
[docs]
def load_public_key(data: bytes) -> RSAPublicKey:
"""Loads the public key and returns a Key interface.
Args:
data: The public key (a bytes-like object) to deserialize.
Returns:
RSAPublicKey: The RSA public key.
"""
return RSAPublicKey.load(data)
[docs]
def load_private_key(
data: bytes,
passphrase: bytes | None = None,
) -> RSAPrivateKey:
"""Loads the private key and returns a Key interface.
Args:
data: The private key (a bytes-like object) to deserialize.
passphrase:
The passphrase that was used to encrypt the private key. ``None``
if the private key is not encrypted.
Returns:
RSAPrivateKey: The RSA private key.
"""
return RSAPrivateKey.load(data, passphrase)