diff --git a/sparkmagic/sparkmagic/kernels/kernelmagics.py b/sparkmagic/sparkmagic/kernels/kernelmagics.py index 727b9937c..30a65cf2b 100644 --- a/sparkmagic/sparkmagic/kernels/kernelmagics.py +++ b/sparkmagic/sparkmagic/kernels/kernelmagics.py @@ -18,7 +18,7 @@ from sparkmagic.utils.sparkevents import SparkEvents from sparkmagic.utils.constants import LANGS_SUPPORTED from sparkmagic.livyclientlib.command import Command -from sparkmagic.livyclientlib.endpoint import Endpoint +from sparkmagic.livyclientlib.endpoint import Endpoint, SSLInfo from sparkmagic.magics.sparkmagicsbase import SparkMagicBase from sparkmagic.livyclientlib.exceptions import handle_expected_exceptions, wrap_unexpected_exceptions, \ BadUserDataException @@ -433,7 +433,12 @@ def matplot(self, line, cell="", local_ns=None): def refresh_configuration(self): credentials = getattr(conf, 'base64_kernel_' + self.language + '_credentials')() (username, password, auth, url) = (credentials['username'], credentials['password'], credentials['auth'], credentials['url']) - self.endpoint = Endpoint(url, auth, username, password) + (ssl_client_cert, ssl_client_key, ssl_verify) = (credentials.get('ssl_client_cert'), credentials.get('ssl_client_key'), credentials.get('ssl_verify'),) + if ssl_client_cert is None: + ssl_info = None + else: + ssl_info = SSLInfo(ssl_client_cert, ssl_client_key, ssl_verify) + self.endpoint = Endpoint(url, auth, username, password, ssl_info=ssl_info) def get_session_settings(self, line, force): line = line.strip() diff --git a/sparkmagic/sparkmagic/livyclientlib/endpoint.py b/sparkmagic/sparkmagic/livyclientlib/endpoint.py index 075e2b69d..a06b10331 100644 --- a/sparkmagic/sparkmagic/livyclientlib/endpoint.py +++ b/sparkmagic/sparkmagic/livyclientlib/endpoint.py @@ -3,7 +3,7 @@ class Endpoint(object): - def __init__(self, url, auth, username="", password="", implicitly_added=False): + def __init__(self, url, auth, username="", password="", implicitly_added=False, ssl_info=None): if not url: raise BadUserDataException(u"URL must not be empty") if auth not in AUTHS_SUPPORTED: @@ -13,6 +13,7 @@ def __init__(self, url, auth, username="", password="", implicitly_added=False): self.username = username self.password = password self.auth = auth + self.ssl_info = ssl_info # implicitly_added is set to True only if the endpoint wasn't configured manually by the user through # a widget, but was instead implicitly defined as an endpoint to a wrapper kernel in the configuration # JSON file. @@ -21,13 +22,37 @@ def __init__(self, url, auth, username="", password="", implicitly_added=False): def __eq__(self, other): if type(other) is not Endpoint: return False - return self.url == other.url and self.username == other.username and self.password == other.password and self.auth == other.auth + return self.url == other.url and self.username == other.username and self.password == other.password and self.auth == other.auth and self.ssl_info == other.ssl_info def __hash__(self): - return hash((self.url, self.username, self.password, self.auth)) + return hash((self.url, self.username, self.password, self.auth, self.ssl_info)) def __ne__(self, other): return not self == other def __str__(self): return u"Endpoint({})".format(self.url) + +class SSLInfo(object): + def __init__(self, client_cert, client_key, ssl_verify): + self.client_cert = client_cert + self.client_key = client_key + self.ssl_verify = ssl_verify + + @property + def cert(self): + return (self.client_cert, self.client_key, ) + + def __eq__(self, other): + if type(other) is not SSLInfo: + return False + return self.client_cert == other.client_cert and self.client_key == other.client_key and self.ssl_verify == other.ssl_verify + + def __hash__(self): + return hash((self.client_cert, self.client_key, self.ssl_verify)) + + def __ne__(self, other): + return not self == other + + def __str__(self): + return u"SSLInfo(client_cert={}, client_key={}, ssl_verify={})".format(self.client_cert, self.client_key, self.ssl_verify) diff --git a/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py b/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py index 5a9ed238a..95dffbe82 100644 --- a/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py +++ b/sparkmagic/sparkmagic/livyclientlib/reliablehttpclient.py @@ -61,15 +61,37 @@ def _send_request_helper(self, url, accepted_status_codes, function, data, retry try: if self._endpoint.auth == constants.NO_AUTH: if data is None: - r = function(url, headers=self._headers, verify=self.verify_ssl) + if self._endpoint.ssl_info is None: + r = function(url, headers=self._headers, verify=self.verify_ssl) + else: + r = function(url, headers=self._headers, + verify=self._endpoint.ssl_info.ssl_verify, + cert=self._endpoint.ssl_info.cert) else: - r = function(url, headers=self._headers, data=json.dumps(data), verify=self.verify_ssl) + if self._endpoint.ssl_info is None: + r = function(url, headers=self._headers, data=json.dumps(data), verify=self.verify_ssl) + else: + r = function(url, headers=self._headers, data=json.dumps(data), + verify=self._endpoint.ssl_info.ssl_verify, + cert=self._endpoint.ssl_info.cert) else: if data is None: - r = function(url, headers=self._headers, auth=self._auth, verify=self.verify_ssl) + if self._endpoint.ssl_info is None: + r = function(url, headers=self._headers, auth=self._auth, verify=self.verify_ssl) + else: + r = function(url, headers=self._headers, auth=self._auth, + verify=self._endpoint.ssl_info.ssl_verify, + cert=self._endpoint.ssl_info.cert) else: - r = function(url, headers=self._headers, auth=self._auth, - data=json.dumps(data), verify=self.verify_ssl) + if self._endpoint.ssl_info is None: + r = function(url, headers=self._headers, auth=self._auth, + data=json.dumps(data), verify=self.verify_ssl) + else: + r = function(url, headers=self._headers, auth=self._auth, + data=json.dumps(data), + verify=self._endpoint.ssl_info.ssl_verify, + cert=self._endpoint.ssl_info.cert) + except requests.exceptions.RequestException as e: error = True r = None diff --git a/sparkmagic/sparkmagic/tests/test_configuration.py b/sparkmagic/sparkmagic/tests/test_configuration.py index 52ca0e2f5..79688db3c 100644 --- a/sparkmagic/sparkmagic/tests/test_configuration.py +++ b/sparkmagic/sparkmagic/tests/test_configuration.py @@ -20,7 +20,10 @@ def test_configuration_override_base64_password(): assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc, conf.livy_session_startup_timeout_seconds.__name__: 1 }) assert_equals(conf.livy_session_startup_timeout_seconds(), 1) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': 'password', 'url': 'L', 'auth': AUTH_BASIC }) + assert_equals(conf.base64_kernel_python_credentials(), { + 'username': 'U', 'password': 'password', 'url': 'L', 'auth': AUTH_BASIC, + 'ssl_client_cert': None, 'ssl_client_key': None, 'ssl_verify': None + }) @with_setup(_setup) @@ -28,7 +31,10 @@ def test_configuration_auth_missing_basic_auth(): kpc = { 'username': 'U', 'password': 'P', 'url': 'L'} overrides = { conf.kernel_python_credentials.__name__: kpc } conf.override_all(overrides) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': 'P', 'url': 'L', 'auth': AUTH_BASIC }) + assert_equals(conf.base64_kernel_python_credentials(), { + 'username': 'U', 'password': 'P', 'url': 'L', 'auth': AUTH_BASIC, + 'ssl_client_cert': None, 'ssl_client_key': None, 'ssl_verify': None + }) @with_setup(_setup) @@ -36,12 +42,18 @@ def test_configuration_auth_missing_no_auth(): kpc = { 'username': '', 'password': '', 'url': 'L'} overrides = { conf.kernel_python_credentials.__name__: kpc } conf.override_all(overrides) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': '', 'password': '', 'url': 'L', 'auth': NO_AUTH }) + assert_equals(conf.base64_kernel_python_credentials(), { + 'username': '', 'password': '', 'url': 'L', 'auth': NO_AUTH, + 'ssl_client_cert': None, 'ssl_client_key': None, 'ssl_verify': None + }) @with_setup(_setup) def test_configuration_override_fallback_to_password(): - kpc = { 'username': 'U', 'password': 'P', 'url': 'L', 'auth': NO_AUTH } + kpc = { + 'username': 'U', 'password': 'P', 'url': 'L', 'auth': NO_AUTH, + 'ssl_client_cert': None, 'ssl_client_key': None, 'ssl_verify': None + } overrides = { conf.kernel_python_credentials.__name__: kpc } conf.override_all(overrides) conf.override(conf.livy_session_startup_timeout_seconds.__name__, 1) @@ -60,7 +72,12 @@ def test_configuration_override_work_with_empty_password(): assert_equals(conf.d, { conf.kernel_python_credentials.__name__: kpc, conf.livy_session_startup_timeout_seconds.__name__: 1 }) assert_equals(conf.livy_session_startup_timeout_seconds(), 1) - assert_equals(conf.base64_kernel_python_credentials(), { 'username': 'U', 'password': '', 'url': '', 'auth': AUTH_BASIC }) + assert_equals( + conf.base64_kernel_python_credentials(), { + 'username': 'U', 'password': '', 'url': '', 'auth': AUTH_BASIC, + 'ssl_client_cert': None, 'ssl_client_key': None, 'ssl_verify': None + } + ) @raises(BadUserConfigurationException) diff --git a/sparkmagic/sparkmagic/utils/configuration.py b/sparkmagic/sparkmagic/utils/configuration.py index 51ae9a31f..4fa145c5c 100644 --- a/sparkmagic/sparkmagic/utils/configuration.py +++ b/sparkmagic/sparkmagic/utils/configuration.py @@ -260,7 +260,9 @@ def _credentials_override(f): If 'base64_password' is not set, it will fallback to 'password' in config. """ credentials = f() - base64_decoded_credentials = {k: credentials.get(k) for k in ('username', 'password', 'url', 'auth')} + base64_decoded_credentials = {k: credentials.get(k) for k in ( + 'username', 'password', 'url', 'auth', 'ssl_client_cert', 'ssl_client_key', 'ssl_verify' + )} base64_password = credentials.get('base64_password') if base64_password is not None: try: