From 72ba0b6bc17eecbfce3f50adbbbd828ed893a865 Mon Sep 17 00:00:00 2001
From: Kashiko Koibumi <kashiko@tuta.io>
Date: Mon, 20 May 2024 15:42:04 +0900
Subject: [PATCH] make runnable with OpenSSL 3 in addition to other versions

---
 src/network/tls.py | 48 ++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 40 insertions(+), 8 deletions(-)

diff --git a/src/network/tls.py b/src/network/tls.py
index a3774b44..87240ce8 100644
--- a/src/network/tls.py
+++ b/src/network/tls.py
@@ -58,25 +58,44 @@ class TLSDispatcher(AdvancedDispatcher):
         self.tlsDone = False
         self.tlsVersion = "N/A"
         self.isSSL = False
+        if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+            self.tlsPrepared = False
 
     def state_tls_init(self):
         """Prepare sockets for TLS handshake"""
         self.isSSL = True
         self.tlsStarted = True
+
+        if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+            self.want_read = self.want_write = True
+            self.set_state("tls_handshake")
+            return False
+
+        do_tls_init(self)
+
+    def do_tls_init(self):
         # Once the connection has been established,
         # it's safe to wrap the socket.
         if sys.version_info >= (2, 7, 9):
-            context = ssl.create_default_context(
-                purpose=ssl.Purpose.SERVER_AUTH
-                if self.server_side else ssl.Purpose.CLIENT_AUTH)
+            if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+                context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+            else:
+                context = ssl.create_default_context(
+                    purpose=ssl.Purpose.SERVER_AUTH
+                    if self.server_side else ssl.Purpose.CLIENT_AUTH)
             context.set_ciphers(self.ciphers)
             context.set_ecdh_curve("secp256k1")
             context.check_hostname = False
             context.verify_mode = ssl.CERT_NONE
             # also exclude TLSv1 and TLSv1.1 in the future
-            context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 |\
-                ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE |\
-                ssl.OP_CIPHER_SERVER_PREFERENCE
+            if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+                context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 |\
+                    ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE |\
+                    ssl.OP_CIPHER_SERVER_PREFERENCE | ssl.OP_NO_TLS1_3
+            else:
+                context.options = ssl.OP_ALL | ssl.OP_NO_SSLv2 |\
+                    ssl.OP_NO_SSLv3 | ssl.OP_SINGLE_ECDH_USE |\
+                    ssl.OP_CIPHER_SERVER_PREFERENCE
             self.sslSocket = context.wrap_socket(
                 self.socket, server_side=self.server_side,
                 do_handshake_on_connect=False)
@@ -88,7 +107,10 @@ class TLSDispatcher(AdvancedDispatcher):
                 ciphers=self.ciphers, do_handshake_on_connect=False)
         self.sslSocket.setblocking(0)
         self.want_read = self.want_write = True
-        self.set_state("tls_handshake")
+        if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+            self.tlsPrepared = True
+        else:
+            self.set_state("tls_handshake")
         return False
 
     @staticmethod
@@ -114,7 +136,9 @@ class TLSDispatcher(AdvancedDispatcher):
             # during TLS handshake, and after flushing write buffer,
             # return status of last handshake attempt
             if self.tlsStarted and not self.tlsDone and not self.write_buf:
-                logger.debug('tls readable, %r', self.want_read)
+                # with OpenSSL 3, excessive logs are spitted
+                if ssl.OPENSSL_VERSION_NUMBER < 0x30000000:
+                    logger.debug('tls readable, %r', self.want_read)
                 return self.want_read
             # prior to TLS handshake,
             # receiveDataThread should emulate synchronous behaviour
@@ -134,6 +158,10 @@ class TLSDispatcher(AdvancedDispatcher):
         try:
             # wait for write buffer flush
             if self.tlsStarted and not self.tlsDone and not self.write_buf:
+                if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+                    if not self.tlsPrepared:
+                        self.do_tls_init()
+                        return
                 self.tls_handshake()
             else:
                 AdvancedDispatcher.handle_read(self)
@@ -156,6 +184,10 @@ class TLSDispatcher(AdvancedDispatcher):
         try:
             # wait for write buffer flush
             if self.tlsStarted and not self.tlsDone and not self.write_buf:
+                if ssl.OPENSSL_VERSION_NUMBER >= 0x30000000:
+                    if not self.tlsPrepared:
+                        self.do_tls_init()
+                        return
                 self.tls_handshake()
             else:
                 AdvancedDispatcher.handle_write(self)