Skip to content

Commit

Permalink
Merge branch 'main' into support/3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuch committed Nov 7, 2023
2 parents 3ce04e4 + 5741d02 commit 420d128
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 31 deletions.
15 changes: 13 additions & 2 deletions code/zato-common/src/zato/common/bearer_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def _build_bearer_token_info(self, sec_def_name:'str', data:'stranydict') -> 'Be
out.expires_in_sec = expires_in_sec
out.expiration_time = expiration_time

# .. indicate that this token was not found in cache ..
out.is_cache_hit = False

# .. and return it to our caller.
return out

# ################################################################################################################################
Expand Down Expand Up @@ -179,7 +183,8 @@ def _get_bearer_token_from_auth_server(
# .. turn into a business object that represents the token ..
info = self._build_bearer_token_info(config.sec_def_name, data)

msg = f'Bearer token received for `{config.sec_def_name}`; expires_in={info.expires_in_sec} ({info.expires_in})'
msg = f'Bearer token received for `{config.sec_def_name}`'
msg += f'; expires_in={info.expires_in_sec} ({info.expires_in} -> {info.expiration_time} UTC)'
msg += f'; scopes={info.scopes}'
logger.info(msg)

Expand Down Expand Up @@ -220,10 +225,11 @@ def _get_bearer_token_from_cache(self, sec_def_name:'str', scopes:'str') -> 'Bea
key = self._get_cache_key(sec_def_name, scopes)

# .. try to get the token information from our cache ..
info = self.cache_api.default.get(key)
info = self.cache_api.default.get(key) # type: BearerTokenInfo

# .. and return it to our caller only if it actually exists.
if info and info != ZATO_NOT_GIVEN:
info.is_cache_hit = True
return info

# ################################################################################################################################
Expand All @@ -243,8 +249,13 @@ def _store_bearer_token_in_cache(self, info:'BearerTokenInfo', scopes:'str') ->
# .. store the token ..
self.cache_api.default.set(key, info, expiry=expiry)

# .. make it known when exactly the key will expire ..
expiry_in = timedelta(seconds=expiry)
expiry_time = datetime.now(tz=timezone.utc) + expiry_in

# .. and log what we have done.
msg = f'Bearer token for `{info.sec_def_name}` cached under key `{key}`'
msg += f'; expiry={expiry} ({expiry_in} -> {expiry_time} UTC)'
logger.info(msg)

# ################################################################################################################################
Expand Down
1 change: 1 addition & 0 deletions code/zato-common/src/zato/common/model/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class BearerTokenConfig(Model):

@dataclass(init=False)
class BearerTokenInfo(Model):
is_cache_hit: 'bool'
sec_def_name: 'str'
token:'str'
token_type:'str'
Expand Down
83 changes: 54 additions & 29 deletions code/zato-server/src/zato/server/connection/http_soap/outgoing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

if 0:
from sqlalchemy.orm.session import Session as SASession
from zato.common.bearer_token import BearerTokenInfo
from zato.common.typing_ import any_, callnone, dictnone, list_, stranydict, strdictnone, strstrdict, type_
from zato.server.base.parallel import ParallelServer
from zato.server.config import ConfigDict
Expand Down Expand Up @@ -142,6 +143,8 @@ def invoke_http(
**kwargs:'any_'
) -> '_RequestsResponse':

# Local variables
params = kwargs.get('params')
json = kwargs.pop('json', None)
cert = self.config['tls_key_cert_full_path'] if self.sec_type == _TLS_Key_Cert else None

Expand All @@ -164,14 +167,17 @@ def invoke_http(
# Force type hints
sec_def_name = cast_('str', sec_def_name)

# Reusable
is_bearer_token = _sec_type == _OAuth

# OAuth scopes can be provided on input even if we do not have a Bearer token definition attached,
# which is why we .pop them here, to make sure they do not propagate to the requests library.
scopes = kwargs.pop('auth_scopes', '')

try:

# Bearer tokens are obtained dynamically ..
if _sec_type == _OAuth:
if is_bearer_token:

# .. this is reusable ..
sec_def = self.server.security_facade.bearer_token[sec_def_name]
Expand All @@ -186,37 +192,64 @@ def invoke_http(
scopes = ' '.join(scopes)

# .. get a Bearer token ..
auth_header = self._get_bearer_token_auth(sec_def_name, scopes, data_format)
info = self._get_bearer_token_auth(sec_def_name, scopes, data_format)

# .. populate headers ..
headers['Authorization'] = auth_header
headers['Authorization'] = f'Bearer {info.token}'

# .. this is needed for later use ..
token_expires_in_sec = info.expires_in_sec
token_is_cache_hit = info.is_cache_hit

# This is needed by request
auth = None

# .. otherwise, the credentials will have been already obtained
# .. but note that Suds connections don't have requests_auth, hence the getattr call.
# .. we enter here if this is not a Bearer token definition ..
else:

# .. otherwise, the credentials will have been already obtained
# .. but note that Suds connections don't have requests_auth, hence the getattr call ..
auth = getattr(self, 'requests_auth', None)

return self.session.request(
# .. we have no token to report about.
token_expires_in_sec = None
token_is_cache_hit = None

# .. basic details about what we are sending what we are sending ..
msg = f'REST out → cid={cid}; {method} {address}; name:{self.config["name"]}; params={params}; len={len(data)}' + \
f'; sec={sec_def_name} ({_sec_type})'

# .. optionally, log details of the Bearer token ..
if is_bearer_token:
msg += f'; token-expiry={token_expires_in_sec}; token-cache-hit={token_is_cache_hit}'

# .. log the information about our request ..
logger.info(msg)

# .. do send it ..
response = self.session.request(
method, address, data=data, json=json, auth=auth, headers=headers, hooks=hooks,
cert=cert, verify=tls_verify, timeout=self.config['timeout'], *args, **kwargs)

# .. log what we received ..
msg = f'REST out ← cid={cid}; {response.status_code} time={response.elapsed}; len={len(response.text)}'
logger.info(msg)

# .. and return it.
return response

except RequestsTimeout:
raise TimeoutException(cid, format_exc())

# ################################################################################################################################

def _get_bearer_token_auth(self, sec_def_name:'str', scopes:'str', data_format:'str') -> 'str':
def _get_bearer_token_auth(self, sec_def_name:'str', scopes:'str', data_format:'str') -> 'BearerTokenInfo':

# This will get the token from cache or from the remote auth. server ..
info = self.server.bearer_token_manager.get_bearer_token_info_by_sec_def_name(sec_def_name, scopes, data_format)

# .. now, we can build the authorization header ..
out = f'Bearer {info.token}'

# .. and return it to our caller.
return out
# .. which we can return to our caller.
return info

# ################################################################################################################################

Expand Down Expand Up @@ -403,7 +436,7 @@ def __str__(self) -> 'str':

# ################################################################################################################################

def format_address(self, cid:'str', params:'stranydict') -> 'tuple[str, dict]':
def format_address(self, cid:'str', params:'stranydict') -> 'tuple[str, stranydict]':
""" Formats a URL path to an external resource. Note that exceptions raised
do not contain anything except for CID. This is in order to keep any potentially
sensitive data from leaking to clients.
Expand Down Expand Up @@ -541,19 +574,11 @@ def http_request(
if isinstance(data, str):
data = data.encode('utf-8')

# Log what we are sending ..
msg = f'REST out → cid={cid}; {method} {address} name:{self.config["name"]}; params={params}; len={len(data)}'
logger.info(msg)

# .. do invoke the connection ..
response = self.invoke_http(cid, method, address, data, headers, {}, params=qs_params, *args, **kwargs)

# .. by default, we have no response at all ..
response.data = None

# .. now, log what we received.
msg = f'REST out ← cid={cid}; {response.status_code} time={response.elapsed}; len={len(response.text)}'
logger.info(msg)
response.data = None # type: ignore

# .. check if we are explicitly told that we handle JSON ..
_has_data_format_json = self.config['data_format'] == DATA_FORMAT.JSON
Expand Down Expand Up @@ -644,7 +669,7 @@ def rest_call(
self,
*,
cid, # type: str
data='', # type: str
data='', # type: ignore
model=None, # type: type_[Model] | None
callback, # type: callnone
params=None, # type: strdictnone
Expand Down Expand Up @@ -684,28 +709,28 @@ def rest_call(
if model:

# .. if this model is actually a list ..
if is_list(model, True):
if is_list(model, True): # type: ignore

# .. extract the underlying model ..
model_class:'type_[Model]' = extract_model_class(model)
model_class:'type_[Model]' = extract_model_class(model) # type: ignore

# .. build a list that we will map the response to ..
data:'list_[Model]' = []
data:'list_[Model]' = [] # type: ignore

# .. go through everything we had in the response ..
for item in response_data:
for item in response_data: # type: ignore

# .. build an actual model instance ..
_item = model_class.from_dict(item)

# .. and append it to the data that we are producing ..
data.append(_item)
data.append(_item) # type: ignore
else:
data:'Model' = model.from_dict(response_data)

# .. if there is no model, use the response as-is ..
else:
data = response_data
data = response_data # type: ignore

# .. run our callback, if there is any ..
if callback:
Expand Down

0 comments on commit 420d128

Please sign in to comment.