Skip to content

Commit

Permalink
pr feedback part 1
Browse files Browse the repository at this point in the history
  • Loading branch information
artek0chumak committed Nov 8, 2024
1 parent 487fbae commit a400517
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
10 changes: 7 additions & 3 deletions src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def create(
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
Defaults to 1.
batch_size (int, optional): Batch size for fine-tuning. Defaults to max.
batch_size (int or "max"): Batch size for fine-tuning. Defaults to max.
learning_rate (float, optional): Learning rate multiplier to use for training
Defaults to 0.00001.
warmup_ratio (float, optional): Warmup ratio for learning rate scheduler.
Expand All @@ -157,7 +157,11 @@ def create(
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
Defaults to None.
train_on_inputs (bool, optional): Whether to mask the user messages in conversational data or prompts in instruction data.
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
"auto" will automatically determine whether to mask the inputs based on the data format.
Dataset with "text" (General format) field will not mask the inputs by default.
Dataset with "messages" (Conversational format) or "prompt" and "completion" (Instruction format)
fields will mask the inputs by default.
Defaults to "auto".
Returns:
Expand Down Expand Up @@ -472,7 +476,7 @@ async def create(
Defaults to False.
model_limits (FinetuneTrainingLimits, optional): Limits for the hyperparameters the model in Fine-tuning.
Defaults to None.
train_on_inputs (bool, optional): Whether to mask the inputs in conversational data. Defaults to "auto".
train_on_inputs (bool or "auto"): Whether to mask the user messages in conversational data or prompts in instruction data.
Returns:
FinetuneResponse: Object containing information about fine-tuning job.
Expand Down
42 changes: 23 additions & 19 deletions src/together/utils/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
)


class InvalidFileFormatError(ValueError):
    """Exception raised for invalid file formats during file checks.

    Subclasses ``ValueError`` so existing callers that catch ``ValueError``
    continue to work, while carrying structured context about where and why
    the file check failed.

    Attributes:
        message: Human-readable description of the formatting problem.
        line_number: 1-based line in the input file where the problem was
            found, or ``None`` when not tied to a specific line.
        error_source: Name of the report field associated with the failure
            (e.g. ``"line_type"``, ``"key_value"``), or ``None``.
    """

    def __init__(
        self,
        message: str = "",
        line_number: int | None = None,
        error_source: str | None = None,
    ) -> None:
        # Pass the message to ValueError so str(exc) and standard
        # exception handling/logging behave as expected.
        super().__init__(message)
        self.message = message
        self.line_number = line_number
        self.error_source = error_source


def check_file(
Expand All @@ -50,7 +50,7 @@ def check_file(
"line_type": None,
"text_field": None,
"key_value": None,
"min_samples": None,
"has_min_samples": None,
"num_samples": None,
"load_json": None,
}
Expand Down Expand Up @@ -121,7 +121,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
'Example of valid json: {"text": "my sample string"}. '
),
line_number=idx + 1,
field="line_type",
error_source="line_type",
)

current_format = None
Expand All @@ -137,6 +137,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
message="Found multiple dataset formats in the input file. "
f"Got {current_format} and {possible_format} on line {idx + 1}.",
line_number=idx + 1,
error_source="format",
)

if current_format is None:
Expand All @@ -146,6 +147,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
f"{json_line.keys()}"
),
line_number=idx + 1,
error_source="format",
)

if current_format == DatasetFormat.CONVERSATION:
Expand All @@ -157,43 +159,43 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
message=f"Invalid format on line {idx + 1} of the input file. "
f"Expected a list of messages. Found {type(json_line[message_column])}",
line_number=idx + 1,
field="key_value",
error_source="key_value",
)

previous_role = ""
for turn in json_line[message_column]:
for column in REQUIRED_COLUMNS_MESSAGE:
if column not in turn:
raise InvalidFileFormatError(
message=f"Field '{column}' is missing for a turn `{turn}` on line {idx + 1} "
message=f"Field `{column}` is missing for a turn `{turn}` on line {idx + 1} "
"of the the input file.",
line_number=idx + 1,
field="key_value",
error_source="key_value",
)
else:
if not isinstance(turn[column], str):
raise InvalidFileFormatError(
message=f"Invalid format on line {idx + 1} in the column {column} for turn `{turn}` "
f"of the input file. Expected string. Found {type(turn[column])}",
line_number=idx + 1,
field="text_field",
error_source="text_field",
)
role = turn["role"]

if role not in POSSIBLE_ROLES_CONVERSATION:
raise InvalidFileFormatError(
message=f"Found invalid role '{role}' in the messages on the line {idx + 1}. "
message=f"Found invalid role `{role}` in the messages on the line {idx + 1}. "
f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}",
line_number=idx + 1,
field="key_value",
error_source="key_value",
)

if previous_role == role:
raise InvalidFileFormatError(
message=f"Invalid role turns on line {idx + 1} of the input file. "
"'user' and 'assistant' roles must alternate user/assistant/user/assistant/...",
"`user` and `assistant` roles must alternate user/assistant/user/assistant/...",
line_number=idx + 1,
field="key_value",
error_source="key_value",
)

previous_role = role
Expand All @@ -205,7 +207,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
message=f'Invalid value type for "{column}" key on line {idx + 1}. '
f"Expected string. Found {type(json_line[column])}.",
line_number=idx + 1,
field="key_value",
error_source="key_value",
)

if dataset_format is None:
Expand All @@ -217,18 +219,19 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
f"Got {dataset_format} for the first line and {current_format} "
f"for the line {idx + 1}.",
line_number=idx + 1,
error_source="format",
)

if idx + 1 < MIN_SAMPLES:
report_dict["min_samples"] = False
report_dict["has_min_samples"] = False
report_dict["message"] = (
f"Processing {file} resulted in only {idx + 1} samples. "
f"Our minimum is {MIN_SAMPLES} samples. "
)
report_dict["is_check_passed"] = False
else:
report_dict["num_samples"] = idx + 1
report_dict["min_samples"] = True
report_dict["has_min_samples"] = True
report_dict["is_check_passed"] = True

report_dict["load_json"] = True
Expand All @@ -251,8 +254,8 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
report_dict["message"] = e.message
if e.line_number is not None:
report_dict["line_number"] = e.line_number
if e.field is not None:
report_dict[e.field] = False
if e.error_source is not None:
report_dict[e.error_source] = False

if "text_field" not in report_dict:
report_dict["text_field"] = True
Expand Down Expand Up @@ -295,7 +298,8 @@ def _check_parquet(file: Path) -> Dict[str, Any]:

num_samples = len(table)
if num_samples < MIN_SAMPLES:
report_dict["min_samples"] = (
report_dict["has_min_samples"] = False
report_dict["message"] = (
f"Processing {file} resulted in only {num_samples} samples. "
f"Our minimum is {MIN_SAMPLES} samples. "
)
Expand Down
32 changes: 16 additions & 16 deletions tests/unit/test_files_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,32 @@ def test_check_jsonl_valid_general(tmp_path: Path):
file = tmp_path / "valid.jsonl"
content = [{"text": "Hello, world!"}, {"text": "How are you?"}]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

assert report["is_check_passed"]
assert report["utf8"]
assert report["num_samples"] == len(content)
assert report["min_samples"] >= MIN_SAMPLES
assert report["has_min_samples"]


def test_check_jsonl_valid_instruction(tmp_path: Path):
    """A valid instruction-format JSONL file should pass all file checks."""
    # Create a valid JSONL file with instruction format
    # (one {"prompt": ..., "completion": ...} object per line).
    file = tmp_path / "valid_instruction.jsonl"
    content = [
        {"prompt": "Translate the following sentence.", "completion": "Hello, world!"},
        {
            "prompt": "Summarize the text.",
            "completion": "Weyland-Yutani Corporation creates advanced AI.",
        },
    ]
    with file.open("w") as f:
        f.write("\n".join(json.dumps(item) for item in content))

    report = check_file(file)

    assert report["is_check_passed"]
    assert report["utf8"]
    assert report["num_samples"] == len(content)
    # has_min_samples is a boolean flag, not a count.
    assert report["has_min_samples"]


def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path):
Expand All @@ -57,14 +57,14 @@ def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path):
},
]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

assert report["is_check_passed"]
assert report["utf8"]
assert report["num_samples"] == len(content)
assert report["min_samples"] >= MIN_SAMPLES
assert report["has_min_samples"]


def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
Expand Down Expand Up @@ -92,14 +92,14 @@ def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
},
]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

assert report["is_check_passed"]
assert report["utf8"]
assert report["num_samples"] == len(content)
assert report["min_samples"] >= MIN_SAMPLES
assert report["has_min_samples"]


def test_check_jsonl_empty_file(tmp_path: Path):
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_check_jsonl_invalid_json(tmp_path: Path):
file = tmp_path / "invalid_json.jsonl"
content = [{"text": "Hello, world!"}, "Invalid JSON Line"]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

Expand All @@ -147,7 +147,7 @@ def test_check_jsonl_missing_required_field(tmp_path: Path):
{"prompt": "Summarize the text."},
]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

Expand All @@ -166,7 +166,7 @@ def test_check_jsonl_inconsistent_dataset_format(tmp_path: Path):
{"text": "How are you?"}, # Missing 'messages'
]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

Expand All @@ -182,7 +182,7 @@ def test_check_jsonl_invalid_role(tmp_path: Path):
file = tmp_path / "invalid_role.jsonl"
content = [{"messages": [{"role": "invalid_role", "content": "Hi"}]}]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

Expand All @@ -202,7 +202,7 @@ def test_check_jsonl_non_alternating_roles(tmp_path: Path):
}
]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)

Expand All @@ -215,7 +215,7 @@ def test_check_jsonl_invalid_value_type(tmp_path: Path):
file = tmp_path / "invalid_value_type.jsonl"
content = [{"text": 123}]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)
assert not report["is_check_passed"]
Expand All @@ -233,7 +233,7 @@ def test_check_jsonl_missing_field_in_conversation(tmp_path: Path):
}
]
with file.open("w") as f:
f.write("\n".join([json.dumps(item) for item in content]))
f.write("\n".join(json.dumps(item) for item in content))

report = check_file(file)
assert not report["is_check_passed"]
Expand Down

0 comments on commit a400517

Please sign in to comment.