# Source code for tgbox.crypto

"""This module stores all cryptography used in API."""

import logging

logger = logging.getLogger(__name__)

from os import urandom
from typing import Union, Optional

from pyaes.util import (
    append_PKCS7_padding,
    strip_PKCS7_padding
)
from .errors import ModeInvalid
try:
    from cryptography.hazmat.primitives.ciphers\
        import Cipher, algorithms, modes
    FAST_ENCRYPTION = True
    logger.info('Fast cryptography library was found.')
except ModuleNotFoundError:
    # We can use PyAES if there is no cryptography library.
    # PyAES is much slower. You can use it for quick tests.
    from pyaes import AESModeOfOperationCBC
    FAST_ENCRYPTION = False
    logger.warning('Fast cryptography library was NOT found. ')
try:
    # Check if cryptg is installed.
    from cryptg import __name__ as _
    del _
    FAST_TELETHON = True
except ModuleNotFoundError:
    FAST_TELETHON = False

# Public API of this module: the names exported by ``from ... import *``.
# ``_PyaesState`` is deliberately absent (internal pyaes fallback).
__all__ = [
    'AESwState',
    'get_rnd_bytes',
    'FAST_TELETHON',
    'FAST_ENCRYPTION',
    'Salt', 'BoxSalt',
    'FileSalt', 'IV'
]
class IV:
    """This is a class-wrapper for AES IV"""

    def __init__(self, iv: Union[bytes, memoryview]):
        """
        Arguments:
            iv (``bytes``, ``memoryview``):
                Raw IV. Anything that is not already ``bytes``
                (e.g. ``memoryview``, ``bytearray``) is copied
                into an immutable ``bytes`` object.
        """
        self.iv = iv if isinstance(iv, bytes) else bytes(iv)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({repr(self.iv)})'

    def __str__(self) -> str:
        return f'{self.__class__.__name__}({repr(self.iv)}) # at {hex(id(self))}'

    def __add__(self, other):
        # Concatenate the raw IV bytes with ``other`` (e.g. ciphertext).
        return self.iv + other

    def __len__(self) -> int:
        return len(self.iv)

    def __eq__(self, other) -> bool:
        # Any object exposing an ``iv`` attribute is compared by that value.
        if hasattr(other, 'iv'):
            return self.iv == other.iv
        return False

    def __hash__(self) -> int:
        # Defining ``__eq__`` alone sets ``__hash__`` to None, making IVs
        # unusable in sets/dicts. Hash exactly what ``__eq__`` compares so
        # equal IVs always hash equally.
        return hash(self.iv)

    @classmethod
    def generate(cls, bytelength: Optional[int] = 16):
        """
        Generates AES IV by ``bytelength``

        Arguments:
            bytelength (``int``, optional):
                Bytelength of IV. 16 bytes by default;
                passing ``None`` also selects the default.
        """
        if bytelength is None:
            # The signature admits None; map it to the documented default
            # instead of passing None down to the random-bytes source.
            bytelength = 16
        return cls(get_rnd_bytes(bytelength))

    def hex(self) -> str:
        """Returns IV as hexadecimal"""
        return self.iv.hex()
class Salt:
    """This is a class-wrapper for some TGBOX salt"""

    def __init__(self, salt: Union[bytes, memoryview]):
        """
        Arguments:
            salt (``bytes``, ``memoryview``):
                Raw salt. Anything that is not already ``bytes``
                (e.g. ``memoryview``, ``bytearray``) is copied
                into an immutable ``bytes`` object.
        """
        self.salt = salt if isinstance(salt, bytes) else bytes(salt)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({repr(self.salt)})'

    def __str__(self) -> str:
        return f'{self.__class__.__name__}({repr(self.salt)}) # at {hex(id(self))}'

    def __add__(self, other):
        # Concatenate the raw salt bytes with ``other``.
        return self.salt + other

    def __len__(self) -> int:
        return len(self.salt)

    def __eq__(self, other) -> bool:
        # Any object exposing a ``salt`` attribute is compared by that value
        # (so e.g. subclasses with equal bytes compare equal).
        if hasattr(other, 'salt'):
            return self.salt == other.salt
        return False

    def __hash__(self) -> int:
        # Defining ``__eq__`` alone sets ``__hash__`` to None, making salts
        # unusable in sets/dicts. Hash exactly what ``__eq__`` compares so
        # equal salts always hash equally.
        return hash(self.salt)

    @classmethod
    def generate(cls, bytelength: Optional[int] = 32):
        """
        Generates Salt by ``bytelength``

        Arguments:
            bytelength (``int``, optional):
                Bytelength of Salt. 32 bytes by default;
                passing ``None`` also selects the default.
        """
        if bytelength is None:
            # The signature admits None; map it to the documented default
            # instead of passing None down to the random-bytes source.
            bytelength = 32
        return cls(get_rnd_bytes(bytelength))

    def hex(self) -> str:
        """Returns Salt as hexadecimal"""
        return self.salt.hex()
class BoxSalt(Salt):
    """
    Salt belonging to a Box.

    Inherits everything from :class:`Salt`; it differs only
    by its type (which also shows up in ``repr``/``str``).
    """
class FileSalt(Salt):
    """
    Salt belonging to a single file.

    Inherits everything from :class:`Salt`; it differs only
    by its type (which also shows up in ``repr``/``str``).
    """
class _PyaesState:
    def __init__(self, key: Union[bytes, 'Key'], iv: IV):
        """
        Class to wrap ``pyaes.AESModeOfOperationCBC``
        if there is no ``FAST_ENCRYPTION``.

        .. note::
            You should use only ``encrypt()`` or
            ``decrypt()`` method per one object.

        Arguments:
            key (``bytes``, ``Key``):
                AES encryption/decryption Key. If the object has
                a ``key`` attribute, that attribute is used.

            iv (``IV``):
                AES Initialization Vector.
        """
        key = key.key if hasattr(key, 'key') else key
        self.iv = iv
        self._aes_state = AESModeOfOperationCBC( # pylint: disable=E0601
            key = bytes(key), iv = self.iv.iv
        )
        self.__mode = None # encrypt mode is 1 and decrypt is 2

    def _lock_mode(self, mode: int, error_message: str) -> None:
        """Bind this object to ``mode`` on first use; raise
        ``ModeInvalid(error_message)`` if another mode is used later."""
        if not self.__mode:
            self.__mode = mode
        elif self.__mode != mode:
            raise ModeInvalid(error_message)

    @staticmethod
    def _process_blocks(data: Union[bytes, memoryview], transform) -> bytes:
        """Apply ``transform`` to every 16-byte block of ``data`` and
        return the concatenated result.

        Raises:
            ValueError: if ``len(data)`` is not divisible by 16.
        """
        # PyAES doesn't support memoryview (nor bytearray); copy such
        # buffers into bytes once. This class only exists on the pyaes
        # path, so no FAST_ENCRYPTION check is needed here.
        if isinstance(data, (memoryview, bytearray)):
            data = bytes(data)

        if len(data) % 16:
            raise ValueError('data length must be divisible by 16')

        # Walk by offset and join once: re-slicing the remaining input
        # and growing the output with ``+=`` on every block was O(n^2).
        return b''.join(
            transform(data[offset:offset + 16])
            for offset in range(0, len(data), 16)
        )

    def encrypt(self, data: Union[bytes, memoryview]) -> bytes:
        """``data`` length must be divisible by 16."""
        self._lock_mode(1, 'You should use only decrypt function.')
        return self._process_blocks(data, self._aes_state.encrypt)

    def decrypt(self, data: Union[bytes, memoryview]) -> bytes:
        """``data`` length must be divisible by 16."""
        self._lock_mode(2, 'You should use only encrypt function.')
        return self._process_blocks(data, self._aes_state.decrypt)
class AESwState:
    """
    Stateful AES-CBC cipher: successive calls continue the
    same CBC chain instead of starting a fresh one.

    .. note::
        An object is locked to the direction of its first call.
        After ``encrypt()`` only ``encrypt()`` may be used, and
        after ``decrypt()`` only ``decrypt()``; the other method
        raises ``ModeInvalid``.
    """
    def __init__(
            self, key: Union[bytes, 'Key'],
            iv: Optional[Union[IV, bytes]] = None
        ):
        """
        Arguments:
            key (``bytes``, ``Key``):
                AES encryption/decryption Key. Objects with a
                ``key`` attribute are unwrapped to that attribute.

            iv (``IV``, ``bytes``, optional):
                AES Initialization Vector; non-``IV`` values are
                wrapped into ``IV``. When omitted, ``encrypt()``
                creates one with ``IV.generate()`` (16 random
                bytes), while ``decrypt()`` takes the first 16
                bytes of the first ciphertext chunk it receives.
        """
        self.key = key.key if hasattr(key, 'key') else key

        if iv and not isinstance(iv, IV):
            iv = IV(iv)

        self.iv = iv
        self.__mode = None
        self._aes_cbc = None
        self.__iv_concated = False

    def __repr__(self) -> str:
        return f'<class {self.__class__.__name__}(<key>, {repr(self.iv)})>'

    def __str__(self) -> str:
        return f'<class {self.__class__.__name__}(<key>, {repr(self.iv)})> # {self.__mode=}'

    def __init_aes_state(self, mode: int) -> None:
        """Create the backend cipher for ``mode`` (1 encrypt, 2 decrypt)."""
        if not FAST_ENCRYPTION:
            self._aes_cbc = _PyaesState(self.key, self.iv)
            return

        self._aes_cbc = cipher = Cipher(
            algorithms.AES(self.key), modes.CBC(self.iv.iv)
        )
        # Expose the context's ``update`` under the same method name that
        # ``_PyaesState`` provides, so both backends are used identically.
        if mode == 1: # Encryption
            cipher.encrypt = cipher.encryptor().update
        else: # Decryption
            cipher.decrypt = cipher.decryptor().update

    @property
    def mode(self) -> int:
        """
        ``1`` once used for encryption, ``2`` once used
        for decryption (``None`` before the first call).
        """
        return self.__mode

    def encrypt(self, data: bytes, pad: bool=True, concat_iv: bool=True) -> bytes:
        """
        Encrypts ``data`` with AES CBC, continuing the chain
        started by earlier calls on this object.

        Arguments:
            pad (``bool``): Apply PKCS7 padding to ``data`` first.
            concat_iv (``bool``): Prefix the result with the raw IV.
                The IV is prefixed at most once per object: only the
                first call made with ``concat_iv=True`` gets it.
        """
        if self.__mode:
            if self.__mode != 1:
                raise ModeInvalid('You should use only decrypt method.')
        else:
            self.__mode = 1
            if not self.iv:
                self.iv = IV.generate()
            self.__init_aes_state(self.__mode)

        plaintext = append_PKCS7_padding(data) if pad else data
        ciphertext = self._aes_cbc.encrypt(plaintext)

        if not concat_iv or self.__iv_concated:
            return ciphertext

        self.__iv_concated = True
        return self.iv.iv + ciphertext

    def decrypt(self, data: bytes, unpad: bool=True) -> bytes:
        """
        Decrypts ``data`` with AES CBC, continuing the chain
        started by earlier calls on this object.

        ``data`` length must be evenly divisible by 16. If no IV
        was given, the first call strips its first 16 bytes and
        uses them as the IV.

        Arguments:
            unpad (``bool``): Strip PKCS7 padding from the result.
        """
        if self.__mode:
            if self.__mode != 2:
                raise ModeInvalid('You should use only encrypt method.')
        else:
            self.__mode = 2
            if not self.iv:
                self.iv = IV(data[:16])
                data = data[16:]
            self.__init_aes_state(self.__mode)

        plaintext = self._aes_cbc.decrypt(data)
        return strip_PKCS7_padding(plaintext) if unpad else plaintext
def get_rnd_bytes(length: int=32) -> bytes:
    """
    Produce ``length`` random bytes taken from ``os.urandom``.

    Arguments:
        length (``int``, optional):
            Number of bytes to return. 32 by default.
    """
    random_bytes = urandom(length)
    return random_bytes