Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
jalencato committed Oct 4, 2024
1 parent 393a126 commit e84d874
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions graphstorm-processing/tests/test_dist_heterogenous_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,14 +1037,14 @@ def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp

def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
"""Test using custom label splits for nodes"""
data = [(i,) for i in range(1, 11)]
data = [(i,) for i in range(1, 120)]

# Create DataFrame
nodes_df = spark.createDataFrame(data, ["orig"])

train_df = spark.createDataFrame([(i,) for i in range(1, 6)], ["mask_id"])
val_df = spark.createDataFrame([(i,) for i in range(6, 9)], ["mask_id"])
test_df = spark.createDataFrame([(i,) for i in range(9, 11)], ["mask_id"])
train_df = spark.createDataFrame([(i,) for i in range(1, 100)], ["mask_id"])
val_df = spark.createDataFrame([(i,) for i in range(101, 110)], ["mask_id"])
test_df = spark.createDataFrame([(i,) for i in range(111, 120)], ["mask_id"])

train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet")
val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet")
Expand All @@ -1054,6 +1054,7 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL
f"custom_split_val_mask",
f"custom_split_test_mask",
]
# Only the custom data split is applied, even though a split rate is also provided
config_dict = {
"column": "orig",
"type": "classification",
Expand Down Expand Up @@ -1084,9 +1085,31 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL
train_total_ones = train_mask_df.agg(F.sum("custom_split_train_mask")).collect()[0][0]
val_total_ones = val_mask_df.agg(F.sum("custom_split_val_mask")).collect()[0][0]
test_total_ones = test_mask_df.agg(F.sum("custom_split_test_mask")).collect()[0][0]
assert train_total_ones == 5
assert val_total_ones == 3
assert test_total_ones == 2
assert train_total_ones == 99
assert val_total_ones == 9
assert test_total_ones == 9

# Verify the row order of the train/val/test mask DataFrames
train_mask_df = train_mask_df.withColumn(
"order_check_id", F.monotonically_increasing_id()
)
val_mask_df = val_mask_df.withColumn(
"order_check_id", F.monotonically_increasing_id()
)
test_mask_df = test_mask_df.withColumn(
"order_check_id", F.monotonically_increasing_id()
)
train_mask_df = train_mask_df.filter((F.col("order_check_id") <= 98)).drop("order_check_id")
val_mask_df = val_mask_df.filter((F.col("order_check_id") >= 100) & (F.col("order_check_id") < 108)).drop("order_check_id")
test_mask_df = test_mask_df.filter((F.col("order_check_id") >= 110) & (F.col("order_check_id") < 118)).drop("order_check_id")

train_unique_rows = train_mask_df.distinct().collect()
train_mask_df.show(n=100)
assert len(train_unique_rows) == 1 and all(value == 1 for value in train_unique_rows[0])
val_unique_rows = val_mask_df.distinct().collect()
assert len(val_unique_rows) == 1 and all(value == 1 for value in val_unique_rows[0])
test_unique_rows = test_mask_df.distinct().collect()
assert len(test_unique_rows) == 1 and all(value == 1 for value in test_unique_rows[0])


def test_edge_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path):
Expand Down

0 comments on commit e84d874

Please sign in to comment.