
Commit

For pyspark
stinodego committed Apr 12, 2024
1 parent 80f023b commit cf49e9d
Showing 2 changed files with 37 additions and 82 deletions.
40 changes: 2 additions & 38 deletions queries/pyspark/executor.py
@@ -1,40 +1,4 @@
-from linetimer import CodeTimer
-
-# TODO: works for now, but need dynamic imports for this.
-from queries.pyspark import (  # noqa: F401
-    q1,
-    q2,
-    q3,
-    q4,
-    q5,
-    q6,
-    q7,
-    q8,
-    q9,
-    q10,
-    q11,
-    q12,
-    q13,
-    q14,
-    q15,
-    q16,
-    q17,
-    q18,
-    q19,
-    q20,
-    q21,
-    q22,
-)
+from queries.common_utils import execute_all

 if __name__ == "__main__":
-    num_queries = 22
-
-    with CodeTimer(name="Overall execution of ALL spark queries", unit="s"):
-        for query_number in range(1, num_queries + 1):
-            submodule = f"q{query_number}"
-            try:
-                eval(f"{submodule}.q()")
-            except Exception as exc:
-                print(
-                    f"Exception occurred while executing PySpark query {query_number}:\n{exc}"
-                )
+    execute_all("pyspark")
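The call above replaces the removed eval-based loop; execute_all itself lives in queries/common_utils.py, which is not one of the two files changed by this commit. A minimal sketch of the dynamic-import approach the removed TODO was asking for might look as follows (the signature and the default query count are assumptions, not the actual implementation):

    # Hypothetical sketch only: the real execute_all is defined in
    # queries/common_utils.py, which this diff does not show.
    from importlib import import_module

    from linetimer import CodeTimer


    def execute_all(library_name: str, num_queries: int = 22) -> None:
        """Run q1 through q22 for the given library via dynamic imports."""
        with CodeTimer(name=f"Overall execution of ALL {library_name} queries", unit="s"):
            for query_number in range(1, num_queries + 1):
                # Import queries.<library>.q<N> and call its q() entry point.
                module = import_module(f"queries.{library_name}.q{query_number}")
                module.q()
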
79 changes: 35 additions & 44 deletions queries/pyspark/utils.py
@@ -4,17 +4,11 @@

 from pyspark.sql import SparkSession

-from queries.common_utils import (
-    check_query_result_pd,
-    on_second_call,
-    run_query_generic,
-)
+from queries.common_utils import check_query_result_pd, run_query_generic
 from settings import Settings

 if TYPE_CHECKING:
-    from pathlib import Path
-
-    from pyspark.sql import DataFrame as SparkDF
+    from pyspark.sql import DataFrame

 settings = Settings()

@@ -31,62 +25,59 @@ def get_or_create_spark() -> SparkSession:
     return spark


-def _read_parquet_ds(path: Path, table_name: str) -> SparkDF:
-    df = get_or_create_spark().read.parquet(str(path))
-    df.createOrReplaceTempView(table_name)
-    return df
+def _read_ds(table_name: str) -> DataFrame:
+    # TODO: Persist data in memory before query
+    if not settings.run.include_io:
+        msg = "cannot run PySpark starting from an in-memory representation"
+        raise RuntimeError(msg)

+    path = settings.dataset_base_dir / f"{table_name}.{settings.run.file_type}"

-@on_second_call
-def get_line_item_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "lineitem.parquet", "lineitem")
+    if settings.run.file_type == "parquet":
+        df = get_or_create_spark().read.parquet(str(path))
+    elif settings.run.file_type == "csv":
+        df = get_or_create_spark().read.csv(str(path), header=True, inferSchema=True)
+    else:
+        msg = f"unsupported file type: {settings.run.file_type!r}"
+        raise ValueError(msg)
+
+    df.createOrReplaceTempView(table_name)
+    return df


-@on_second_call
-def get_orders_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "orders.parquet", "orders")
+def get_line_item_ds() -> DataFrame:
+    return _read_ds("lineitem")


-@on_second_call
-def get_customer_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "customer.parquet", "customer")
+def get_orders_ds() -> DataFrame:
+    return _read_ds("orders")


-@on_second_call
-def get_region_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "region.parquet", "region")
+def get_customer_ds() -> DataFrame:
+    return _read_ds("customer")


-@on_second_call
-def get_nation_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "nation.parquet", "nation")
+def get_region_ds() -> DataFrame:
+    return _read_ds("region")


-@on_second_call
-def get_supplier_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "supplier.parquet", "supplier")
+def get_nation_ds() -> DataFrame:
+    return _read_ds("nation")


-@on_second_call
-def get_part_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "part.parquet", "part")
+def get_supplier_ds() -> DataFrame:
+    return _read_ds("supplier")


-@on_second_call
-def get_part_supp_ds() -> SparkDF:
-    return _read_parquet_ds(settings.dataset_base_dir / "partsupp.parquet", "partsupp")
+def get_part_ds() -> DataFrame:
+    return _read_ds("part")


-def drop_temp_view() -> None:
-    spark = get_or_create_spark()
-    [
-        spark.catalog.dropTempView(t.name)
-        for t in spark.catalog.listTables()
-        if t.isTemporary
-    ]
+def get_part_supp_ds() -> DataFrame:
+    return _read_ds("partsupp")


-def run_query(query_number: int, df: SparkDF) -> None:
+def run_query(query_number: int, df: DataFrame) -> None:
     query = df.toPandas
     run_query_generic(
         query, query_number, "pyspark", query_checker=check_query_result_pd
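The new _read_ds helper registers each table as a Spark temp view, and run_query hands the DataFrame's bound toPandas method to run_query_generic for timing and result checking. The query modules themselves (queries/pyspark/q1.py and so on) are not part of this commit; a usage sketch under that assumption, with a placeholder query instead of the real TPC-H SQL, could look like this:

    # Hypothetical usage sketch; q1.py is not shown in this diff and the SQL
    # below is a stand-in, not TPC-H Q1.
    from queries.pyspark.utils import get_line_item_ds, get_or_create_spark, run_query


    def q() -> None:
        # Reads lineitem.<file_type> and registers it as the temp view "lineitem".
        get_line_item_ds()

        # Spark SQL against the registered view yields a pyspark DataFrame.
        result = get_or_create_spark().sql("SELECT COUNT(*) AS row_count FROM lineitem")

        # run_query passes result.toPandas (uncalled) to run_query_generic,
        # which is expected to execute, time, and check it.
        run_query(1, result)
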
