Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: #1276 add Asyncio SQLAlchemy support #1633

Merged
merged 7 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions requirements/testing.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ boto3<=2
# For AWS tests
moto>=4.0.13,<6
mypy<=1.14.1
# For AsyncSQLAlchemy tests
greenlet<=4
aiosqlite<=1
281 changes: 280 additions & 1 deletion slack_sdk/oauth/installation_store/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
)
from sqlalchemy.engine import Engine
from sqlalchemy.sql.sqltypes import Boolean

from sqlalchemy.ext.asyncio import AsyncEngine
from slack_sdk.oauth.installation_store.installation_store import InstallationStore
from slack_sdk.oauth.installation_store.models.bot import Bot
from slack_sdk.oauth.installation_store.models.installation import Installation
from slack_sdk.oauth.installation_store.async_installation_store import (
AsyncInstallationStore,
)


class SQLAlchemyInstallationStore(InstallationStore):
Expand Down Expand Up @@ -362,3 +365,279 @@ def delete_installation(
)
)
conn.execute(deletion)


class AsyncSQLAlchemyInstallationStore(AsyncInstallationStore):
default_bots_table_name: str = "slack_bots"
default_installations_table_name: str = "slack_installations"

client_id: str
engine: AsyncEngine
metadata: MetaData
installations: Table

def __init__(
self,
client_id: str,
engine: AsyncEngine,
bots_table_name: str = default_bots_table_name,
installations_table_name: str = default_installations_table_name,
logger: Logger = logging.getLogger(__name__),
):
self.metadata = sqlalchemy.MetaData()
self.bots = self.build_bots_table(metadata=self.metadata, table_name=bots_table_name)
self.installations = self.build_installations_table(metadata=self.metadata, table_name=installations_table_name)
self.client_id = client_id
self._logger = logger
self.engine = engine

@classmethod
def build_installations_table(cls, metadata: MetaData, table_name: str) -> Table:
return SQLAlchemyInstallationStore.build_installations_table(metadata, table_name)

@classmethod
def build_bots_table(cls, metadata: MetaData, table_name: str) -> Table:
return SQLAlchemyInstallationStore.build_bots_table(metadata, table_name)

async def create_tables(self):
async with self.engine.begin() as conn:
await conn.run_sync(self.metadata.create_all)

@property
def logger(self) -> Logger:
return self._logger

async def async_save(self, installation: Installation):
async with self.engine.begin() as conn:
i = installation.to_dict()
i["client_id"] = self.client_id

i_column = self.installations.c
installations_rows = await conn.execute(
sqlalchemy.select(i_column.id)
.where(
and_(
i_column.client_id == self.client_id,
i_column.enterprise_id == installation.enterprise_id,
i_column.team_id == installation.team_id,
i_column.installed_at == i.get("installed_at"),
)
)
.limit(1)
)
installations_row_id: Optional[str] = None
for row in installations_rows.mappings():
installations_row_id = row["id"]
if installations_row_id is None:
await conn.execute(self.installations.insert(), i)
else:
update_statement = self.installations.update().where(i_column.id == installations_row_id).values(**i)
await conn.execute(update_statement, i)

# bots
await self.async_save_bot(installation.to_bot())

async def async_save_bot(self, bot: Bot):
async with self.engine.begin() as conn:
# bots
b = bot.to_dict()
b["client_id"] = self.client_id

b_column = self.bots.c
bots_rows = await conn.execute(
sqlalchemy.select(b_column.id)
.where(
and_(
b_column.client_id == self.client_id,
b_column.enterprise_id == bot.enterprise_id,
b_column.team_id == bot.team_id,
b_column.installed_at == b.get("installed_at"),
)
)
.limit(1)
)
bots_row_id: Optional[str] = None
for row in bots_rows.mappings():
bots_row_id = row["id"]
if bots_row_id is None:
await conn.execute(self.bots.insert(), b)
else:
update_statement = self.bots.update().where(b_column.id == bots_row_id).values(**b)
await conn.execute(update_statement, b)

async def async_find_bot(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
is_enterprise_install: Optional[bool] = False,
) -> Optional[Bot]:
if is_enterprise_install or team_id is None:
team_id = None

c = self.bots.c
query = (
self.bots.select()
.where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
c.bot_token.is_not(None), # the latest one that has a bot token
)
)
.order_by(desc(c.installed_at))
.limit(1)
)

async with self.engine.connect() as conn:
result: object = await conn.execute(query)
for row in result.mappings(): # type: ignore[attr-defined]
return Bot(
galuszkak marked this conversation as resolved.
Show resolved Hide resolved
app_id=row["app_id"],
enterprise_id=row["enterprise_id"],
enterprise_name=row["enterprise_name"],
team_id=row["team_id"],
team_name=row["team_name"],
bot_token=row["bot_token"],
bot_id=row["bot_id"],
bot_user_id=row["bot_user_id"],
bot_scopes=row["bot_scopes"],
bot_refresh_token=row["bot_refresh_token"],
bot_token_expires_at=row["bot_token_expires_at"],
is_enterprise_install=row["is_enterprise_install"],
installed_at=row["installed_at"],
)
return None

async def async_find_installation(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
user_id: Optional[str] = None,
is_enterprise_install: Optional[bool] = False,
) -> Optional[Installation]:
if is_enterprise_install or team_id is None:
team_id = None

c = self.installations.c
where_clause = and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
)
if user_id is not None:
where_clause = and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
c.user_id == user_id,
)

query = self.installations.select().where(where_clause).order_by(desc(c.installed_at)).limit(1)

installation: Optional[Installation] = None
async with self.engine.connect() as conn:
result: object = await conn.execute(query)
for row in result.mappings(): # type: ignore[attr-defined]
installation = Installation(
galuszkak marked this conversation as resolved.
Show resolved Hide resolved
app_id=row["app_id"],
enterprise_id=row["enterprise_id"],
enterprise_name=row["enterprise_name"],
enterprise_url=row["enterprise_url"],
team_id=row["team_id"],
team_name=row["team_name"],
bot_token=row["bot_token"],
bot_id=row["bot_id"],
bot_user_id=row["bot_user_id"],
bot_scopes=row["bot_scopes"],
bot_refresh_token=row["bot_refresh_token"],
bot_token_expires_at=row["bot_token_expires_at"],
user_id=row["user_id"],
user_token=row["user_token"],
user_scopes=row["user_scopes"],
user_refresh_token=row["user_refresh_token"],
user_token_expires_at=row["user_token_expires_at"],
# Only the incoming webhook issued in the latest installation is set in this logic
incoming_webhook_url=row["incoming_webhook_url"],
incoming_webhook_channel=row["incoming_webhook_channel"],
incoming_webhook_channel_id=row["incoming_webhook_channel_id"],
incoming_webhook_configuration_url=row["incoming_webhook_configuration_url"],
is_enterprise_install=row["is_enterprise_install"],
token_type=row["token_type"],
installed_at=row["installed_at"],
)

has_user_installation = user_id is not None and installation is not None
no_bot_token_installation = installation is not None and installation.bot_token is None
should_find_bot_installation = has_user_installation or no_bot_token_installation
if should_find_bot_installation:
# Retrieve the latest bot token, just in case
# See also: https://github.com/slackapi/bolt-python/issues/664
latest_bot_installation = await self.async_find_bot(
enterprise_id=enterprise_id,
team_id=team_id,
is_enterprise_install=is_enterprise_install,
)
if (
latest_bot_installation is not None
and installation is not None
and installation.bot_token != latest_bot_installation.bot_token
):
installation.bot_id = latest_bot_installation.bot_id
installation.bot_user_id = latest_bot_installation.bot_user_id
installation.bot_token = latest_bot_installation.bot_token
installation.bot_scopes = latest_bot_installation.bot_scopes
installation.bot_refresh_token = latest_bot_installation.bot_refresh_token
installation.bot_token_expires_at = latest_bot_installation.bot_token_expires_at

return installation

async def async_delete_bot(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
) -> None:
table = self.bots
c = table.c
async with self.engine.begin() as conn:
deletion = table.delete().where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
)
)
await conn.execute(deletion)

async def async_delete_installation(
self,
*,
enterprise_id: Optional[str],
team_id: Optional[str],
user_id: Optional[str] = None,
) -> None:
table = self.installations
c = table.c
async with self.engine.begin() as conn:
if user_id is not None:
deletion = table.delete().where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
c.user_id == user_id,
)
)
await conn.execute(deletion)
else:
deletion = table.delete().where(
and_(
c.client_id == self.client_id,
c.enterprise_id == enterprise_id,
c.team_id == team_id,
)
)
await conn.execute(deletion)
71 changes: 71 additions & 0 deletions slack_sdk/oauth/state_store/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from uuid import uuid4

from ..state_store import OAuthStateStore
from ..async_state_store import AsyncOAuthStateStore
import sqlalchemy
from sqlalchemy import Table, Column, Integer, String, DateTime, and_, MetaData
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import AsyncEngine


class SQLAlchemyOAuthStateStore(OAuthStateStore):
Expand Down Expand Up @@ -76,3 +78,72 @@ def consume(self, state: str) -> bool:
message = f"Failed to find any persistent data for state: {state} - {e}"
self.logger.warning(message)
return False


class AsyncSQLAlchemyOAuthStateStore(AsyncOAuthStateStore):
default_table_name: str = "slack_oauth_states"

expiration_seconds: int
engine: AsyncEngine
metadata: MetaData
oauth_states: Table

@classmethod
def build_oauth_states_table(cls, metadata: MetaData, table_name: str) -> Table:
return sqlalchemy.Table(
table_name,
metadata,
metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("state", String(200), nullable=False),
Column("expire_at", DateTime, nullable=False),
)

def __init__(
self,
expiration_seconds: int,
engine: Engine,
logger: Logger = logging.getLogger(__name__),
table_name: str = default_table_name,
):
self.expiration_seconds = expiration_seconds
self._logger = logger
self.engine = engine
self.metadata = MetaData()
self.oauth_states = self.build_oauth_states_table(self.metadata, table_name)

async def create_tables(self):
async with self.engine.begin() as conn:
await conn.run_sync(self.metadata.create_all)

@property
def logger(self) -> Logger:
if self._logger is None:
self._logger = logging.getLogger(__name__)
return self._logger

async def async_issue(self, *args, **kwargs) -> str:
state: str = str(uuid4())
now = datetime.utcfromtimestamp(time.time() + self.expiration_seconds)
async with self.engine.begin() as conn:
await conn.execute(
self.oauth_states.insert(),
{"state": state, "expire_at": now},
)
return state

async def async_consume(self, state: str) -> bool:
try:
async with self.engine.begin() as conn:
c = self.oauth_states.c
query = self.oauth_states.select().where(and_(c.state == state, c.expire_at > datetime.utcnow()))
result = await conn.execute(query)
for row in result.mappings():
self.logger.debug(f"consume's query result: {row}")
await conn.execute(self.oauth_states.delete().where(c.id == row["id"]))
return True
return False
except Exception as e:
message = f"Failed to find any persistent data for state: {state} - {e}"
self.logger.warning(message)
return False
Loading
Loading