Copy certificate subject from peercertificate, use ssl.PROTOCOL_TLSv1_2 for client wrap and allow TLSv1_1 for server wrap (#370)

This commit is contained in:
Abhinav Singh 2020-06-13 21:42:12 +05:30 committed by GitHub
parent d6e60774ae
commit 1b2966140c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 8 deletions

View File

@ -39,7 +39,7 @@ countryName_min = 2
countryName_max = 2
stateOrProvinceName = State or Province Name (full name)
localityName = Locality Name (eg, city)
0.organizationName = Organization Name (eg, company)
organizationName = Organization Name (eg, company)
organizationalUnitName = Organizational Unit Name (eg, section)
commonName = Common Name (eg, fully qualified host name)
commonName_max = 64

View File

@ -356,7 +356,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
self.response.total_size,
connection_time_ms))
def gen_ca_signed_certificate(self, cert_file_path: str) -> None:
def gen_ca_signed_certificate(self, cert_file_path: str, certificate: Dict[str, Any]) -> None:
'''CA signing key (default) is used for generating a public key
for common_name, if one already doesn't exist. Using generated
public key a CSR request is generated, which is then signed by
@ -366,11 +366,19 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
returns signed certificate path.'''
assert(self.request.host and self.flags.ca_cert_dir and self.flags.ca_signing_key_file and
self.flags.ca_key_file and self.flags.ca_cert_file)
upstream_subject = {s[0][0]: s[0][1] for s in certificate['subject']}
public_key_path = os.path.join(self.flags.ca_cert_dir,
'{0}.{1}'.format(text_(self.request.host), 'pub'))
private_key_path = self.flags.ca_signing_key_file
private_key_password = ''
subject = '/CN={0}'.format(text_(self.request.host))
subject = '/CN={0}/C={1}/ST={2}/L={3}/O={4}/OU={5}'.format(
upstream_subject.get('commonName', text_(self.request.host)),
upstream_subject.get('countryName', 'NA'),
upstream_subject.get('stateOrProvinceName', 'Unavailable'),
upstream_subject.get('localityName', 'Unavailable'),
upstream_subject.get('organizationName', 'Unavailable'),
upstream_subject.get('organizationalUnitName', 'Unavailable'))
alt_subj_names = [text_(self.request.host), ]
validity_in_days = 365 * 2
timeout = 10
@ -412,7 +420,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
return os.path.join(ca_cert_dir, '%s.pem' % host)
def generate_upstream_certificate(
self, _certificate: Optional[Dict[str, Any]]) -> str:
self, certificate: Dict[str, Any]) -> str:
if not (self.flags.ca_cert_dir and self.flags.ca_signing_key_file and
self.flags.ca_cert_file and self.flags.ca_key_file):
raise HttpProtocolException(
@ -424,7 +432,7 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
self.flags.ca_cert_dir, text_(self.request.host))
with self.lock:
if not os.path.isfile(cert_file_path):
self.gen_ca_signed_certificate(cert_file_path)
self.gen_ca_signed_certificate(cert_file_path, certificate)
return cert_file_path
def wrap_server(self) -> None:
@ -432,7 +440,8 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
assert isinstance(self.server.connection, socket.socket)
ctx = ssl.create_default_context(
ssl.Purpose.SERVER_AUTH, cafile=self.flags.ca_file)
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
ctx.options |= ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1
ctx.check_hostname = True
self.server.connection.setblocking(True)
self.server._conn = ctx.wrap_socket(
self.server.connection,
@ -450,7 +459,8 @@ class HttpProxyPlugin(HttpProtocolHandlerPlugin):
self.client.connection,
server_side=True,
certfile=generated_cert,
keyfile=self.flags.ca_signing_key_file)
keyfile=self.flags.ca_signing_key_file,
ssl_version=ssl.PROTOCOL_TLSv1_2)
self.client.connection.setblocking(False)
logger.debug(
'TLS interception using %s', generated_cert)

View File

@ -168,7 +168,8 @@ class TestHttpProxyTlsInterception(unittest.TestCase):
server_side=True,
keyfile=self.flags.ca_signing_key_file,
certfile=HttpProxyPlugin.generated_cert_file_path(
self.flags.ca_cert_dir, host)
self.flags.ca_cert_dir, host),
ssl_version=ssl.PROTOCOL_TLSv1_2
)
self.assertEqual(self._conn.setblocking.call_count, 2)
self.assertEqual(