PyBitmessage/src/pyelliptic/ecc.py
Peter Surda 2e0f7755c6
Blind signature support in pyelliptic
- add blind signature functionality to pyelliptic as described in #1409
- add tests for blind signatures
- PEP8 fixes for pyelliptic
- some minor refactoring is necessary for further integration, this is just a
  minimal implementation to pass a test
2019-08-27 23:13:45 +02:00

506 lines
20 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
src/pyelliptic/ecc.py
=====================
"""
# pylint: disable=protected-access
# Copyright (C) 2011 Yann GUIBET <yannguibet@gmail.com>
# See LICENSE for details.
from hashlib import sha512
from struct import pack, unpack
from cipher import Cipher
from hash import equals, hmac_sha256
from openssl import OpenSSL
class ECC(object):
"""
Asymmetric cryptography with Elliptic Curve Cryptography (ECC):
ECDH key agreement, ECDSA signatures and ECIES encryption
>>> import pyelliptic
>>> alice = pyelliptic.ECC() # default curve: sect283r1
>>> bob = pyelliptic.ECC(curve='sect571r1')
>>> ciphertext = alice.encrypt("Hello Bob", bob.get_pubkey())
>>> print bob.decrypt(ciphertext)
>>> signature = bob.sign("Hello Alice")
>>> # alice's job :
>>> print pyelliptic.ECC(
>>> pubkey=bob.get_pubkey()).verify(signature, "Hello Alice")
>>> # ERROR: alice (sect283r1) and bob (sect571r1) use different curves
>>> try:
>>> key = alice.get_ecdh_key(bob.get_pubkey())
>>> except:
>>> print("For ECDH key agreement, the keys must be defined on the same curve !")
>>> alice = pyelliptic.ECC(curve='sect571r1')
>>> print alice.get_ecdh_key(bob.get_pubkey()).encode('hex')
>>> print bob.get_ecdh_key(alice.get_pubkey()).encode('hex')
"""
def __init__(
    self,
    pubkey=None,
    privkey=None,
    pubkey_x=None,
    pubkey_y=None,
    raw_privkey=None,
    curve='sect283r1',
):  # pylint: disable=too-many-arguments
    """
    Build an ECC object from existing key material or a fresh key pair.

    For normal, high level use, specify ``pubkey``, ``privkey`` (if you
    need it) and the curve.  Key sources are tried in this order:

    1. ``pubkey_x`` and ``pubkey_y`` (raw coordinates, optionally with
       ``raw_privkey``) on the requested ``curve``;
    2. ``pubkey`` (serialised, optionally with serialised ``privkey``);
       the curve stored inside the public key replaces ``curve``;
    3. otherwise a new key pair is generated on ``curve``.

    ``curve`` may be a curve name (``str``) or an already resolved
    curve identifier.  Raises ``Exception`` when the serialised public
    and private keys name different curves or the keys fail validation.
    """
    self.curve = OpenSSL.get_curve(curve) if isinstance(curve, str) else curve

    # Raw coordinates take precedence over every other key source.
    have_raw_point = not (pubkey_x is None or pubkey_y is None)
    if have_raw_point:
        self._set_keys(pubkey_x, pubkey_y, raw_privkey)
        return

    # Nothing supplied at all: create a brand new key pair.
    if pubkey is None:
        generated = self._generate()
        self.privkey, self.pubkey_x, self.pubkey_y = generated
        return

    key_curve, pubkey_x, pubkey_y, _ = ECC._decode_pubkey(pubkey)
    if privkey is not None:
        priv_curve, raw_privkey, _ = ECC._decode_privkey(privkey)
        if key_curve != priv_curve:
            raise Exception("Bad ECC keys ...")
    # The curve embedded in the serialised key wins over the argument.
    self.curve = key_curve
    self._set_keys(pubkey_x, pubkey_y, raw_privkey)
def _set_keys(self, pubkey_x, pubkey_y, privkey):
    """
    Validate the given raw key material and store it on the instance.

    ``pubkey_x`` / ``pubkey_y`` are the raw public point coordinates and
    ``privkey`` is the raw private key or ``None``.

    The key attributes are cleared before validation so that they exist
    (as ``None``) even when ``raw_check_key`` reports failure by raising,
    which would otherwise skip any reset done after the call.

    Raises ``Exception`` when validation fails.
    """
    self.pubkey_x = None
    self.pubkey_y = None
    self.privkey = None
    if self.raw_check_key(privkey, pubkey_x, pubkey_y) < 0:
        raise Exception("Bad ECC keys ...")
    self.pubkey_x = pubkey_x
    self.pubkey_y = pubkey_y
    self.privkey = privkey
@staticmethod
def get_curves():
    """
    Static method: return the keys of ``OpenSSL.curves``, i.e. the
    identifiers of all the curves the wrapper makes available.
    """
    known_curves = OpenSSL.curves
    return known_curves.keys()
def get_curve(self):
    """
    Pass the curve identifier stored on this object to
    ``OpenSSL.get_curve_by_id`` and return the result
    (presumably the curve's name -- confirm against the wrapper).
    """
    curve_id = self.curve
    return OpenSSL.get_curve_by_id(curve_id)
def get_curve_id(self):
    """Return the identifier of the curve currently in use."""
    curve_id = self.curve
    return curve_id
def get_pubkey(self):
    """
    High level function which returns the serialised public key:
    curve(2) + len_of_pubkeyX(2) + pubkeyX + len_of_pubkeyY(2) + pubkeyY

    Every integer is packed as a big-endian unsigned short.
    """
    fields = [pack('!H', self.curve)]
    for coordinate in (self.pubkey_x, self.pubkey_y):
        # Each coordinate is preceded by its own two-byte length.
        fields.append(pack('!H', len(coordinate)))
        fields.append(coordinate)
    return b''.join(fields)
def get_privkey(self):
    """
    High level function which returns the serialised private key:
    curve(2) + len_of_privkey(2) + privkey

    Both integers are packed as big-endian unsigned shorts.
    """
    secret = self.privkey
    header = pack('!H', self.curve) + pack('!H', len(secret))
    return b''.join((header, secret))
@staticmethod
def _decode_pubkey(pubkey):
    """
    Parse a public key serialised as
    curve(2) + len_of_pubkeyX(2) + pubkeyX + len_of_pubkeyY(2) + pubkeyY
    where every integer is a big-endian unsigned short.

    Returns ``(curve, pubkey_x, pubkey_y, consumed)``; ``consumed`` is the
    number of bytes read.  Bytes after the key are ignored, so ``pubkey``
    may be the start of a longer buffer.

    Raises ``struct.error`` when the buffer ends before a header or before
    the number of bytes a length field announces.  Without the explicit
    check, slicing would silently return a shortened coordinate and an
    offset pointing past the end of the data.
    """
    from struct import error as struct_error

    def read_short(offset):
        # unpack() itself raises struct.error when fewer than 2 bytes remain.
        return unpack('!H', pubkey[offset:offset + 2])[0], offset + 2

    def read_field(offset):
        length, start = read_short(offset)
        end = start + length
        if end > len(pubkey):
            raise struct_error(
                "truncated key: field needs %d bytes, %d available"
                % (length, len(pubkey) - start))
        return pubkey[start:end], end

    curve, offset = read_short(0)
    pubkey_x, offset = read_field(offset)
    pubkey_y, offset = read_field(offset)
    return curve, pubkey_x, pubkey_y, offset
@staticmethod
def _decode_privkey(privkey):
    """
    Parse a private key serialised as
    curve(2) + len_of_privkey(2) + privkey
    where both integers are big-endian unsigned shorts.

    Returns ``(curve, raw_privkey, consumed)``; ``consumed`` is the number
    of bytes read.  Bytes after the key are ignored.

    Raises ``struct.error`` when the buffer ends before a header or before
    the number of bytes the length field announces.  Without the explicit
    check, slicing would silently return a shortened key.
    """
    from struct import error as struct_error

    # unpack() raises struct.error by itself when a header is incomplete.
    curve = unpack('!H', privkey[0:2])[0]
    length = unpack('!H', privkey[2:4])[0]
    start = 4
    end = start + length
    if end > len(privkey):
        raise struct_error(
            "truncated key: field needs %d bytes, %d available"
            % (length, len(privkey) - start))
    return curve, privkey[start:end], end
def _generate(self):
    """
    Create a new key pair on ``self.curve``.

    Returns ``(privkey, pubkey_x, pubkey_y)``, each as the raw bytes
    written by ``BN_bn2bin``.  The generated material is passed through
    ``raw_check_key`` before it is returned.

    Raises ``Exception`` when an OpenSSL call reports failure.
    """
    # Pre-bind every handle released in ``finally`` so that an early
    # failure cannot turn into an UnboundLocalError hiding the real error.
    key = pub_key_x = pub_key_y = None
    try:
        pub_key_x = OpenSSL.BN_new()
        pub_key_y = OpenSSL.BN_new()

        key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
        # ``not key`` catches 0 as well as None; ctypes returns None for
        # a NULL pointer when the function's restype is a pointer type.
        if not key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")
        if (OpenSSL.EC_KEY_generate_key(key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_generate_key FAIL ...")
        if (OpenSSL.EC_KEY_check_key(key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_check_key FAIL ...")

        # The get0 accessors hand out pointers that belong to ``key``
        # (OpenSSL naming convention); they are not freed separately.
        priv_key = OpenSSL.EC_KEY_get0_private_key(key)
        group = OpenSSL.EC_KEY_get0_group(key)
        pub_key = OpenSSL.EC_KEY_get0_public_key(key)

        if OpenSSL.EC_POINT_get_affine_coordinates_GFp(
                group, pub_key, pub_key_x, pub_key_y, 0) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_get_affine_coordinates_GFp FAIL ...")

        # Export each big number into a buffer of exactly its byte size.
        privkey = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(priv_key))
        pubkeyx = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(pub_key_x))
        pubkeyy = OpenSSL.malloc(0, OpenSSL.BN_num_bytes(pub_key_y))
        OpenSSL.BN_bn2bin(priv_key, privkey)
        privkey = privkey.raw
        OpenSSL.BN_bn2bin(pub_key_x, pubkeyx)
        pubkeyx = pubkeyx.raw
        OpenSSL.BN_bn2bin(pub_key_y, pubkeyy)
        pubkeyy = pubkeyy.raw

        self.raw_check_key(privkey, pubkeyx, pubkeyy)
        return privkey, pubkeyx, pubkeyy
    finally:
        if key:
            OpenSSL.EC_KEY_free(key)
        if pub_key_x:
            OpenSSL.BN_free(pub_key_x)
        if pub_key_y:
            OpenSSL.BN_free(pub_key_y)
def get_ecdh_key(self, pubkey):
    """
    High level function.  Combine the serialised public key of the peer
    with the local private key and return the SHA-512 digest of the raw
    ECDH result, i.e. a 512 bit shared key.

    Raises ``Exception`` when the peer key names a different curve.
    """
    peer_curve, peer_x, peer_y, _ = ECC._decode_pubkey(pubkey)
    if peer_curve == self.curve:
        raw_secret = self.raw_get_ecdh_key(peer_x, peer_y)
        return sha512(raw_secret).digest()
    raise Exception("ECC keys must be from the same curve !")
def raw_get_ecdh_key(self, pubkey_x, pubkey_y):
    """
    Compute the raw ECDH secret between the local private key and the
    peer public point given as raw coordinates ``pubkey_x``/``pubkey_y``.

    Returns the 32 bytes written by ``ECDH_compute_key``.

    Raises ``Exception`` when an OpenSSL call reports failure, including
    a peer point rejected by ``EC_KEY_check_key`` and a result whose
    length is not 32.
    """
    # Pre-bind every handle released in ``finally``.  The peer key is
    # validated before the local key objects are created, so without this
    # an invalid peer key would surface as an UnboundLocalError.
    other_key = other_pub_key_x = other_pub_key_y = None
    other_pub_key = own_key = own_priv_key = None
    try:
        ecdh_keybuffer = OpenSSL.malloc(0, 32)

        # --- load and validate the peer's public point ---
        other_key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
        # ``not`` catches 0 as well as None; ctypes returns None for a
        # NULL pointer when the function's restype is a pointer type.
        if not other_key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")

        other_pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
        other_pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), 0)

        other_group = OpenSSL.EC_KEY_get0_group(other_key)
        other_pub_key = OpenSSL.EC_POINT_new(other_group)

        if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(other_group,
                                                        other_pub_key,
                                                        other_pub_key_x,
                                                        other_pub_key_y,
                                                        0)) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ...")
        if (OpenSSL.EC_KEY_set_public_key(other_key, other_pub_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ...")
        if (OpenSSL.EC_KEY_check_key(other_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_check_key FAIL ...")

        # --- load the local private key ---
        own_key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
        if not own_key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")
        own_priv_key = OpenSSL.BN_bin2bn(
            self.privkey, len(self.privkey), 0)

        if (OpenSSL.EC_KEY_set_private_key(own_key, own_priv_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ...")

        # The method-selection call differs between library generations.
        if OpenSSL._hexversion > 0x10100000 and not OpenSSL._libreSSL:
            OpenSSL.EC_KEY_set_method(own_key, OpenSSL.EC_KEY_OpenSSL())
        else:
            OpenSSL.ECDH_set_method(own_key, OpenSSL.ECDH_OpenSSL())

        ecdh_keylen = OpenSSL.ECDH_compute_key(
            ecdh_keybuffer, 32, other_pub_key, own_key, 0)

        if ecdh_keylen != 32:
            raise Exception("[OpenSSL] ECDH keylen FAIL ...")

        return ecdh_keybuffer.raw
    finally:
        # Release only what was actually allocated, in the original order.
        for release, handle in (
                (OpenSSL.EC_KEY_free, other_key),
                (OpenSSL.BN_free, other_pub_key_x),
                (OpenSSL.BN_free, other_pub_key_y),
                (OpenSSL.EC_POINT_free, other_pub_key),
                (OpenSSL.EC_KEY_free, own_key),
                (OpenSSL.BN_free, own_priv_key),
        ):
            if handle:
                release(handle)
def check_key(self, privkey, pubkey):
    """
    Check a serialised public key and, optionally, the matching
    serialised private key (pass ``None`` to check the public key alone).

    Returns the result of ``raw_check_key`` for the decoded material.
    Raises ``Exception`` when both keys are given but name different
    curves.
    """
    pub_curve, pubkey_x, pubkey_y, _ = ECC._decode_pubkey(pubkey)

    raw_privkey = None
    if privkey is not None:
        priv_curve, raw_privkey, _ = ECC._decode_privkey(privkey)
        if priv_curve != pub_curve:
            raise Exception("Bad public and private key")

    return self.raw_check_key(raw_privkey, pubkey_x, pubkey_y, pub_curve)
def raw_check_key(self, privkey, pubkey_x, pubkey_y, curve=None):
    """
    Check key validity; the key is supplied as binary data.

    ``privkey`` may be ``None`` to check the public point alone.
    ``curve`` may be a curve name (``str``), an already resolved curve
    identifier, or ``None`` to use ``self.curve``.

    Returns 0 when ``EC_KEY_check_key`` accepts the key.
    Raises ``Exception`` when an OpenSSL call reports failure.
    """
    # pylint: disable=too-many-branches
    if curve is None:
        curve = self.curve
    elif isinstance(curve, str):
        curve = OpenSSL.get_curve(curve)

    # Pre-bind every handle released in ``finally`` so that an early
    # failure cannot turn into an UnboundLocalError hiding the real error.
    key = priv_key = pub_key_x = pub_key_y = pub_key = None
    try:
        key = OpenSSL.EC_KEY_new_by_curve_name(curve)
        # ``not key`` catches 0 as well as None; ctypes returns None for
        # a NULL pointer when the function's restype is a pointer type.
        if not key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")

        if privkey is not None:
            priv_key = OpenSSL.BN_bin2bn(privkey, len(privkey), 0)
        pub_key_x = OpenSSL.BN_bin2bn(pubkey_x, len(pubkey_x), 0)
        pub_key_y = OpenSSL.BN_bin2bn(pubkey_y, len(pubkey_y), 0)

        if privkey is not None:
            if (OpenSSL.EC_KEY_set_private_key(key, priv_key)) == 0:
                raise Exception(
                    "[OpenSSL] EC_KEY_set_private_key FAIL ...")

        group = OpenSSL.EC_KEY_get0_group(key)
        pub_key = OpenSSL.EC_POINT_new(group)

        if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(group, pub_key,
                                                        pub_key_x,
                                                        pub_key_y,
                                                        0)) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ...")
        if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ...")
        if (OpenSSL.EC_KEY_check_key(key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_check_key FAIL ...")
        return 0
    finally:
        # Release only what was actually allocated, in the original order.
        for release, handle in (
                (OpenSSL.EC_KEY_free, key),
                (OpenSSL.BN_free, pub_key_x),
                (OpenSSL.BN_free, pub_key_y),
                (OpenSSL.EC_POINT_free, pub_key),
                (OpenSSL.BN_free, priv_key),
        ):
            if handle:
                release(handle)
def sign(self, inputb, digest_alg=OpenSSL.digest_ecdsa_sha1):
    """
    Sign the input with the ECDSA method and return the signature.

    ``inputb`` is hashed with the digest returned by ``digest_alg()`` and
    the digest is signed with the local key pair.  The new signature is
    checked with ``ECDSA_verify`` before it is returned.

    Returns the signature bytes written by ``ECDSA_sign``.
    Raises ``Exception`` when an OpenSSL call reports failure.
    """
    # pylint: disable=too-many-branches,too-many-locals
    # Selects between the EVP_MD_CTX_new/free and the
    # EVP_MD_CTX_create/init/destroy families of calls.
    modern_evp = OpenSSL._hexversion > 0x10100000 and not OpenSSL._libreSSL

    # Pre-bind every handle released in ``finally`` so that an early
    # failure cannot turn into an UnboundLocalError hiding the real error.
    key = priv_key = pub_key_x = pub_key_y = pub_key = md_ctx = None
    try:
        size = len(inputb)
        buff = OpenSSL.malloc(inputb, size)
        digest = OpenSSL.malloc(0, 64)
        if modern_evp:
            md_ctx = OpenSSL.EVP_MD_CTX_new()
        else:
            md_ctx = OpenSSL.EVP_MD_CTX_create()
        dgst_len = OpenSSL.pointer(OpenSSL.c_int(0))
        siglen = OpenSSL.pointer(OpenSSL.c_int(0))
        sig = OpenSSL.malloc(0, 151)

        key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
        # ``not key`` catches 0 as well as None; ctypes returns None for
        # a NULL pointer when the function's restype is a pointer type.
        if not key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")

        priv_key = OpenSSL.BN_bin2bn(self.privkey, len(self.privkey), 0)
        pub_key_x = OpenSSL.BN_bin2bn(self.pubkey_x, len(self.pubkey_x), 0)
        pub_key_y = OpenSSL.BN_bin2bn(self.pubkey_y, len(self.pubkey_y), 0)

        if (OpenSSL.EC_KEY_set_private_key(key, priv_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_set_private_key FAIL ...")

        group = OpenSSL.EC_KEY_get0_group(key)
        pub_key = OpenSSL.EC_POINT_new(group)

        if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(group, pub_key,
                                                        pub_key_x,
                                                        pub_key_y,
                                                        0)) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ...")
        if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ...")
        if (OpenSSL.EC_KEY_check_key(key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_check_key FAIL ...")

        # EVP_MD_CTX_new() takes no argument and allocates a context, so
        # calling it again here as an "init" step would only leak a second
        # context.  Only the older create() API gets the explicit init.
        if not modern_evp:
            OpenSSL.EVP_MD_CTX_init(md_ctx)
        OpenSSL.EVP_DigestInit_ex(md_ctx, digest_alg(), None)

        if (OpenSSL.EVP_DigestUpdate(md_ctx, buff, size)) == 0:
            raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ...")
        OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len)

        OpenSSL.ECDSA_sign(0, digest, dgst_len.contents, sig, siglen, key)
        # Verifying right away also catches a failed ECDSA_sign call.
        if (OpenSSL.ECDSA_verify(0, digest, dgst_len.contents, sig,
                                 siglen.contents, key)) != 1:
            raise Exception("[OpenSSL] ECDSA_verify FAIL ...")

        return sig.raw[:siglen.contents.value]
    finally:
        # Release only what was actually allocated, in the original order.
        for release, handle in (
                (OpenSSL.EC_KEY_free, key),
                (OpenSSL.BN_free, pub_key_x),
                (OpenSSL.BN_free, pub_key_y),
                (OpenSSL.BN_free, priv_key),
                (OpenSSL.EC_POINT_free, pub_key),
        ):
            if handle:
                release(handle)
        if md_ctx:
            if modern_evp:
                OpenSSL.EVP_MD_CTX_free(md_ctx)
            else:
                OpenSSL.EVP_MD_CTX_destroy(md_ctx)
def verify(self, sig, inputb, digest_alg=OpenSSL.digest_ecdsa_sha1):
    """
    Verify the signature with the input and the local public key.

    ``inputb`` is hashed with the digest returned by ``digest_alg()`` and
    ``sig`` is checked against that digest.

    Returns ``True`` only when ``ECDSA_verify`` reports 1; a bad signature
    (0) and a failed check (-1) both give ``False``.
    Raises ``Exception`` when the local public key cannot be loaded or
    the digest update fails.
    """
    # pylint: disable=too-many-branches
    # Selects between the EVP_MD_CTX_new/free and the
    # EVP_MD_CTX_create/init/destroy families of calls.
    modern_evp = OpenSSL._hexversion > 0x10100000 and not OpenSSL._libreSSL

    # Pre-bind every handle released in ``finally`` so that an early
    # failure cannot turn into an UnboundLocalError hiding the real error.
    key = pub_key_x = pub_key_y = pub_key = md_ctx = None
    try:
        bsig = OpenSSL.malloc(sig, len(sig))
        binputb = OpenSSL.malloc(inputb, len(inputb))
        digest = OpenSSL.malloc(0, 64)
        dgst_len = OpenSSL.pointer(OpenSSL.c_int(0))
        if modern_evp:
            md_ctx = OpenSSL.EVP_MD_CTX_new()
        else:
            md_ctx = OpenSSL.EVP_MD_CTX_create()

        key = OpenSSL.EC_KEY_new_by_curve_name(self.curve)
        # ``not key`` catches 0 as well as None; ctypes returns None for
        # a NULL pointer when the function's restype is a pointer type.
        if not key:
            raise Exception("[OpenSSL] EC_KEY_new_by_curve_name FAIL ...")

        pub_key_x = OpenSSL.BN_bin2bn(self.pubkey_x, len(self.pubkey_x), 0)
        pub_key_y = OpenSSL.BN_bin2bn(self.pubkey_y, len(self.pubkey_y), 0)
        group = OpenSSL.EC_KEY_get0_group(key)
        pub_key = OpenSSL.EC_POINT_new(group)

        if (OpenSSL.EC_POINT_set_affine_coordinates_GFp(group, pub_key,
                                                        pub_key_x,
                                                        pub_key_y,
                                                        0)) == 0:
            raise Exception(
                "[OpenSSL] EC_POINT_set_affine_coordinates_GFp FAIL ...")
        if (OpenSSL.EC_KEY_set_public_key(key, pub_key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_set_public_key FAIL ...")
        if (OpenSSL.EC_KEY_check_key(key)) == 0:
            raise Exception("[OpenSSL] EC_KEY_check_key FAIL ...")

        # EVP_MD_CTX_new() takes no argument and allocates a context, so
        # calling it again here as an "init" step would only leak a second
        # context.  Only the older create() API gets the explicit init.
        if not modern_evp:
            OpenSSL.EVP_MD_CTX_init(md_ctx)
        OpenSSL.EVP_DigestInit_ex(md_ctx, digest_alg(), None)
        if (OpenSSL.EVP_DigestUpdate(md_ctx, binputb, len(inputb))) == 0:
            raise Exception("[OpenSSL] EVP_DigestUpdate FAIL ...")

        OpenSSL.EVP_DigestFinal_ex(md_ctx, digest, dgst_len)
        ret = OpenSSL.ECDSA_verify(
            0, digest, dgst_len.contents, bsig, len(sig), key)

        # 1 = good signature, 0 = bad signature, -1 = check failed.
        return ret == 1
    finally:
        # Release only what was actually allocated, in the original order.
        for release, handle in (
                (OpenSSL.EC_KEY_free, key),
                (OpenSSL.BN_free, pub_key_x),
                (OpenSSL.BN_free, pub_key_y),
                (OpenSSL.EC_POINT_free, pub_key),
        ):
            if handle:
                release(handle)
        if md_ctx:
            if modern_evp:
                OpenSSL.EVP_MD_CTX_free(md_ctx)
            else:
                OpenSSL.EVP_MD_CTX_destroy(md_ctx)
@staticmethod
def encrypt(data, pubkey, ephemcurve=None, ciphername='aes-256-cbc'):
    """
    Encrypt data with the ECIES method using the serialised public key
    of the recipient.

    The key is decoded and the work is delegated to ``raw_encrypt`` with
    the curve stored in the key.
    """
    decoded = ECC._decode_pubkey(pubkey)
    recipient_curve, recipient_x, recipient_y = decoded[:3]
    return ECC.raw_encrypt(
        data,
        recipient_x,
        recipient_y,
        curve=recipient_curve,
        ephemcurve=ephemcurve,
        ciphername=ciphername,
    )
@staticmethod
def raw_encrypt(
    data,
    pubkey_x,
    pubkey_y,
    curve='sect283r1',
    ephemcurve=None,
    ciphername='aes-256-cbc',
):  # pylint: disable=too-many-arguments
    """
    ECIES encryption; the recipient key is supplied as raw coordinates.

    An ephemeral key pair is created on ``ephemcurve`` (``curve`` when
    not given).  The SHA-512 digest of the ECDH secret is split into an
    encryption key (first 32 bytes) and a MAC key (remaining bytes).

    Returns iv + ephemeral public key + ciphertext + HMAC-SHA256 of
    everything before it.
    """
    ephemeral = ECC(curve=curve if ephemcurve is None else ephemcurve)

    secret = sha512(ephemeral.raw_get_ecdh_key(pubkey_x, pubkey_y)).digest()
    key_e = secret[:32]
    key_m = secret[32:]

    ephemeral_pubkey = ephemeral.get_pubkey()
    blocksize = OpenSSL.get_cipher(ciphername).get_blocksize()
    iv = OpenSSL.rand(blocksize)

    encryptor = Cipher(key_e, iv, 1, ciphername)
    authenticated = iv + ephemeral_pubkey + encryptor.ciphering(data)
    return authenticated + hmac_sha256(key_m, authenticated)
def decrypt(self, data, ciphername='aes-256-cbc'):
    """
    Decrypt data with the ECIES method using the local private key.

    ``data`` is read as iv + sender public key + ciphertext + MAC, where
    the MAC is whatever follows the ciphertext (the last 32 bytes for a
    well-formed message).  The MAC is verified before anything is
    decrypted.

    Raises ``RuntimeError`` when the MAC does not match.
    """
    # pylint: disable=too-many-locals
    blocksize = OpenSSL.get_cipher(ciphername).get_blocksize()
    iv = data[:blocksize]

    _, sender_x, sender_y, key_len = ECC._decode_pubkey(data[blocksize:])
    body_start = blocksize + key_len
    mac_boundary = len(data) - 32

    ciphertext = data[body_start:mac_boundary]
    mac = data[body_start + len(ciphertext):]

    secret = sha512(self.raw_get_ecdh_key(sender_x, sender_y)).digest()
    key_e = secret[:32]
    key_m = secret[32:]

    expected_mac = hmac_sha256(key_m, data[:mac_boundary])
    if equals(expected_mac, mac):
        return Cipher(key_e, iv, 0, ciphername).ciphering(ciphertext)
    raise RuntimeError("Fail to verify data")