Skip to content

Commit

Permalink
auth azure: support access token
Browse files Browse the repository at this point in the history
  • Loading branch information
LiliDeng committed Jan 7, 2025
1 parent 21abd2a commit f35f51f
Showing 1 changed file with 57 additions and 6 deletions.
63 changes: 57 additions & 6 deletions lisa/sut_orchestrator/azure/platform_.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import base64
import copy
import json
import logging
Expand All @@ -17,6 +18,7 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

import requests
from azure.core.credentials import AccessToken, TokenCredential
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.mgmt.compute.models import (
Expand Down Expand Up @@ -246,6 +248,7 @@ class AzurePlatformSchema:
),
)
service_principal_key: str = field(default="")
access_token: str = field(default="")
subscription_id: str = field(
default="",
metadata=field_metadata(
Expand Down Expand Up @@ -320,6 +323,7 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
"service_principal_tenant_id",
"service_principal_client_id",
"service_principal_key",
"access_token",
"subscription_id",
"shared_resource_group_name",
"resource_group_name",
Expand All @@ -338,6 +342,8 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
add_secret(self.subscription_id, mask=PATTERN_GUID)
if self.service_principal_key:
add_secret(self.service_principal_key)
if self.access_token:
add_secret(self.access_token)
if self.service_principal_client_id:
add_secret(self.service_principal_client_id, mask=PATTERN_GUID)

Expand Down Expand Up @@ -400,21 +406,44 @@ def cloud(self, value: Optional[CloudSchema]) -> None:
self.cloud_raw = value.to_dict() # type: ignore


class StaticAccessTokenCredential(TokenCredential):
def __init__(self, token: str, expires_on: int) -> None:
"""
Initialize StaticAccessTokenCredential with the provided token and expiry time.
:param token: The Azure access token as a string.
:param expires_on: The expiry time of the token as an integer (Unix timestamp).
"""
self._token = token
self._expires_on = expires_on

def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
"""
Get the access token for the specified scopes.
:param scopes: The OAuth 2.0 scopes the token applies to.
:param kwargs: Additional keyword arguments that may be required by the SDK.
:return: An AccessToken instance containing the token and its expiry time.
"""
# You can choose to print or log the scopes and kwargs for debugging if needed
return AccessToken(self._token, self._expires_on)


class AzurePlatform(Platform):
_diagnostic_storage_container_pattern = re.compile(
r"(https:\/\/)(?P<storage_name>.*)([.].*){4}\/(?P<container_name>.*)\/",
re.M,
)
_arm_template: Any = None

_credentials: Dict[str, DefaultAzureCredential] = {}
_credentials: Dict[str, Union[DefaultAzureCredential, TokenCredential]] = {}
_locations_data_cache: Dict[str, AzureLocation] = {}

def __init__(self, runbook: schema.Platform) -> None:
super().__init__(runbook=runbook)

# for type detection
self.credential: DefaultAzureCredential
self.credential: Union[DefaultAzureCredential, TokenCredential]
self.cloud: Cloud

# It has to be defined after the class definition is loaded. So it
Expand Down Expand Up @@ -914,6 +943,16 @@ def _initialize(self, *args: Any, **kwargs: Any) -> None:
self.credential, self.subscription_id, self.cloud
)

def decode_jwt(self, token: str) -> Any:
# The second part of the JWT is the payload
payload = token.split(".")[1]
# Add padding to ensure Base64 decoding works properly
padded_payload = payload + "=" * (4 - len(payload) % 4)
# Decode the Base64 URL-safe encoded payload
decoded_payload = base64.urlsafe_b64decode(padded_payload)
# Convert the payload into a dictionary
return json.loads(decoded_payload)

def _initialize_credential(self) -> None:
azure_runbook = self._azure_runbook

Expand All @@ -936,10 +975,22 @@ def _initialize_credential(self) -> None:
] = azure_runbook.service_principal_client_id
if azure_runbook.service_principal_key:
os.environ["AZURE_CLIENT_SECRET"] = azure_runbook.service_principal_key

credential = DefaultAzureCredential(
authority=self.cloud.endpoints.active_directory,
)
if azure_runbook.access_token:
os.environ["AZURE_ACCESS_TOKEN"] = azure_runbook.access_token

if "AZURE_ACCESS_TOKEN" in os.environ:
# Parse the token to get the expiration timestamp
decoded_token = self.decode_jwt(os.environ["AZURE_ACCESS_TOKEN"])
# 'exp' is the UNIX timestamp for expiration
expiry_timestamp = decoded_token.get("exp")
credential = StaticAccessTokenCredential(
os.environ["AZURE_ACCESS_TOKEN"],
expiry_timestamp,
)
else:
credential = DefaultAzureCredential(
authority=self.cloud.endpoints.active_directory,
)

with SubscriptionClient(
credential,
Expand Down

0 comments on commit f35f51f

Please sign in to comment.