feat: support classification #95
Conversation
…into revision-bg
Reviewer's Guide by Sourcery

This pull request introduces support for classification tasks in the existing machine learning pipeline. It includes significant changes to the model architecture, benchmarking process, and data handling. The changes are implemented across multiple files, with major updates to the core functionality in src/mattext/models/benchmark.py, src/mattext/main.py, src/mattext/models/finetune.py, src/mattext/models/predict.py, and src/mattext/models/score.py. New configuration files and data preparation scripts have also been added to support the classification tasks.

File-Level Changes
Hey @n0w0f - I've reviewed your changes - here's some feedback:
Overall Comments:
- Consider refactoring common code patterns across the new classification and benchmark files to reduce duplication and improve maintainability.
- Improve consistency in the use of type hints and docstrings throughout the new code to enhance readability and maintainability.
- Review the configuration files for different models and representations, and consider creating a more modular structure to reduce repetition in the YAML files.
Here's what I looked at during the review
- 🟡 General issues: 4 issues found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 3 issues found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment to tell me if it was helpful.
    return prediction_df, prediction_ids

def evaluate(self, true_labels: List[int]) -> dict:
suggestion: Consider incorporating true labels into _prepare_datasets method
The evaluate method requires true_labels as a separate argument. Consider loading true labels along with the test data in the _prepare_datasets method for better encapsulation and consistency.
def evaluate(self, test_data: pd.DataFrame) -> dict:
    true_labels = test_data['label'].tolist()
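A slightly fuller sketch of that suggestion (a hypothetical shape, not the PR's code; it assumes predict() returns the per-class probability dataframe shown elsewhere in this diff and that the test dataframe carries a 'label' column):

from typing import List

import numpy as np
import pandas as pd


class InferenceClassification:
    # ... existing __init__, _prepare_datasets and predict() ...

    def evaluate(self, test_data: pd.DataFrame) -> dict:
        # Take the labels from the same dataframe the test split was built from,
        # instead of passing them in as a separate argument.
        true_labels: List[int] = test_data["label"].tolist()
        probabilities, _ = self.predict()
        predicted = np.argmax(probabilities.values, axis=1)
        accuracy = float((predicted == np.array(true_labels)).mean())
        return {"accuracy": accuracy}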
@sourcery-ai review
Hey @n0w0f - I've reviewed your changes - here's some feedback:
Overall Comments:
- Consider refactoring the new classification classes to reduce code duplication with existing regression classes.
- Evaluate the possibility of separating classification logic more distinctly from the existing benchmark and inference code to improve maintainability.
Here's what I looked at during the review
- 🟡 General issues: 1 issue found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 2 issues found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment to tell me if it was helpful.
@@ -22,6 +22,9 @@ def __init__(self):
    def run_task(self, run: list, task_cfg: DictConfig, local_rank=None) -> None:
issue (complexity): Consider refactoring to reduce duplication and improve maintainability.
The recent changes have increased the complexity of the code, primarily due to the added conditional logic and duplication in the run_task method. The introduction of the "classification" condition has increased the number of branches, making the logic harder to follow. Additionally, the run_classification method is very similar to run_benchmarking, leading to duplicated logic that can complicate maintenance. To address these issues, consider refactoring to reduce duplication and improve maintainability. For example, you could use a helper method to consolidate common logic and a dictionary to map task names to methods, simplifying the run_task logic and making it easier to extend or modify. This approach would centralize task execution logic, improving maintainability and reducing complexity.
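A sketch of the suggested dispatch-table approach (class name, handler names, and the config key used to pick the task are illustrative assumptions, not the PR's actual code):

from omegaconf import DictConfig


class TaskRunner:
    def __init__(self):
        # Map task names to handler methods instead of branching in run_task.
        self._handlers = {
            "benchmark": self._run_benchmarking,
            "classification": self._run_classification,
        }

    def run_task(self, run: list, task_cfg: DictConfig, local_rank=None) -> None:
        task = task_cfg.get("task", None)  # hypothetical config key
        handler = self._handlers.get(task)
        if handler is None:
            print(f"Unknown task: {task}")
            return
        handler(run, task_cfg, local_rank)

    def _run_benchmarking(self, run, task_cfg, local_rank):
        self._execute(run, task_cfg, local_rank, task_type="regression")

    def _run_classification(self, run, task_cfg, local_rank):
        self._execute(run, task_cfg, local_rank, task_type="classification")

    def _execute(self, run, task_cfg, local_rank, task_type):
        # The finetune/predict/score steps currently duplicated between
        # run_benchmarking and run_classification would be shared here.
        ...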
src/mattext/models/benchmark.py (outdated)
    name=exp_name,
)
fold_name = fold_key_namer(i)
print("-------------------------")
issue (code-quality): Extract duplicate code into method (extract-duplicate-method)
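A minimal illustration of that extract-duplicate-method refactor, based on the quoted per-fold setup; the helper name, the wandb.init call, and the import path of fold_key_namer are assumptions:

import wandb

# Assumed import path for the project's existing fold-naming helper.
from mattext.models.utils import fold_key_namer


def _init_fold_run(exp_name: str, fold_index: int) -> str:
    """Shared per-fold setup (in the codebase this would likely be a method)."""
    wandb.init(name=exp_name)  # remaining wandb arguments omitted in this sketch
    fold_name = fold_key_namer(fold_index)
    print("-------------------------")
    return fold_name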
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
As you and Sourcery said, there is a lot of duplication, and I'm unsure if we should merge it in the current form.
    return None


def process_entry_test_matbench(entry: List, timeout: int) -> dict:
we should also try to revise functions like this one later - e.g. timeout seems not to be used
batch_iterator = (data[i:i + batch_size] for i in range(0, len(data), batch_size))

for i, batch_data in enumerate(batch_iterator, start=1):
why do you actually start at 1 and not 0?
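For reference, the only observable difference is the value of the batch counter i; a standalone illustration (toy data, not the PR's):

data = list(range(10))
batch_size = 4
batch_iterator = (data[i:i + batch_size] for i in range(0, len(data), batch_size))

# start=1 makes the counter run 1, 2, 3 (convenient for "batch 1 of N" logging);
# with the default start=0 it would run 0, 1, 2. The batch contents are the same.
for i, batch_data in enumerate(batch_iterator, start=1):
    print(f"batch {i}: {batch_data}")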
Is there anything particular you have concerns about, @n0w0f?
I wanted to just confirm that the metrics are correctly computed.
@kjappelbaum I agree, we need not merge this. I did not want to break the codebase, hence the duplications; I can clean up after the sprint.
I'd also consider moving to something like unsloth for fine-tuning, as it is much faster.
Another point we had mentioned on Zulip was to add
@sourcery-ai review
Hey @n0w0f - I've reviewed your changes and they look great!
Here's what I looked at during the review
- 🟡 General issues: 4 issues found
- 🟢 Security: all looks good
- 🟢 Testing: all looks good
- 🟡 Complexity: 1 issue found
- 🟢 Documentation: all looks good
Help me be more useful! Please click 👍 or 👎 on each comment to tell me if it was helpful.
print(
    f"Error occurred during inference for finetuned checkpoint '{exp_name}': {str(e)}"
)
if isinstance(e, (ValueError, TypeError)):
suggestion: Improve granularity of exception handling
Consider handling more specific exceptions and providing more detailed error messages or logging. This will make debugging easier and provide more informative feedback.
if isinstance(e, (ValueError, TypeError, RuntimeError, AttributeError)):
    logger.error(f"Critical error during inference: {e}")
    raise
elif isinstance(e, (FileNotFoundError, IOError)):
    logger.warning(f"File-related error during inference: {e}")
else:
    logger.exception("Unexpected error during inference")
Could be a different PR; the logger is not used yet.
That's a valid point. If the logger isn't currently being used in the codebase, introducing it might be better suited for a separate PR to keep changes focused and manageable. For now, you could consider adding more specific exception handling without the logger, and then plan to integrate logging in a future update. This way, you can still improve the granularity of exception handling in this PR without introducing new dependencies. Let me know if you need any help with that!
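A sketch of what more specific handling without a logger could look like (the wrapper name is hypothetical; it keeps the PR's print-based reporting):

import traceback


def run_inference_safely(run_fn, exp_name: str):
    # run_fn: a zero-argument callable that performs inference for one checkpoint.
    try:
        return run_fn()
    except (ValueError, TypeError, RuntimeError) as e:
        # Problems with the checkpoint or its inputs: report and fail loudly.
        print(f"Critical error during inference for '{exp_name}': {e}")
        raise
    except (FileNotFoundError, OSError) as e:
        # Missing or unreadable files: report and skip this checkpoint.
        print(f"File-related error during inference for '{exp_name}': {e}")
        return None
    except Exception as e:
        # Anything else: keep the traceback around for debugging.
        traceback.print_exc()
        print(f"Unexpected error during inference for '{exp_name}': {e}")
        return None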
else:
    print(f"Unknown task: {task}")


def _run_experiment(
suggestion: Break down _run_experiment into smaller methods
The _run_experiment method is quite long and handles multiple responsibilities. Consider breaking it down into smaller, more focused methods to improve readability and maintainability.

def _run_experiment(self, task_cfg: DictConfig):
    return self._execute_experiment_steps(task_cfg)

def _execute_experiment_steps(self, task_cfg: DictConfig):
    # Move the existing content of _run_experiment here
    # and break it down into smaller methods as needed
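For concreteness, one possible split into smaller steps (step names are hypothetical; the real contents of _run_experiment are not shown in this diff):

from omegaconf import DictConfig


def _run_experiment(self, task_cfg: DictConfig, local_rank=None) -> dict:
    # Each step below would absorb one responsibility of the current method.
    train_ds, test_ds = self._prepare_experiment_data(task_cfg)
    checkpoint = self._finetune_model(task_cfg, train_ds, local_rank)
    predictions = self._run_inference(task_cfg, checkpoint, test_ds)
    return self._score_predictions(task_cfg, predictions)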
    probabilities, columns=[f"class_{i}" for i in range(self.num_labels)]
)


def evaluate(self, true_labels: List[int]) -> dict:
suggestion: Break down evaluate method into smaller functions
The evaluate method in InferenceClassification is quite long and performs multiple operations. Consider breaking it down into smaller, more focused methods for each evaluation metric or step.

def evaluate(self, true_labels: List[int]) -> dict:
    predictions = self._get_predictions()
    metrics = self._calculate_metrics(predictions, true_labels)
    return metrics

def _get_predictions(self) -> np.ndarray:
    predictions, _ = self.predict()
    return np.argmax(predictions.values, axis=1)

def _calculate_metrics(self, predictions: np.ndarray, true_labels: List[int]) -> dict:
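A possible body for the truncated _calculate_metrics helper (a hedged sketch using scikit-learn; the PR may compute a different metric set):

from typing import List

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


def _calculate_metrics(self, predictions: np.ndarray, true_labels: List[int]) -> dict:
    # Standard multi-class metrics from hard predictions; probability-based
    # metrics such as ROC-AUC would need the raw class probabilities instead.
    return {
        "accuracy": accuracy_score(true_labels, predictions),
        "precision_macro": precision_score(true_labels, predictions, average="macro", zero_division=0),
        "recall_macro": recall_score(true_labels, predictions, average="macro", zero_division=0),
        "f1_macro": f1_score(true_labels, predictions, average="macro"),
    }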
Refactored to incorporate the comments. Can you @sourcery-ai review?
There is a lot of duplication, kept for now so as to avoid any breakage at this moment. We might have to refactor later. We could also choose not to merge it, but it would be good to review the code.
Summary by Sourcery
Add support for classification tasks by introducing new classes and methods, refactor benchmarking and task execution logic for better modularity, and enhance model fine-tuning and inference processes with abstract base classes. Include new configuration files and scripts for data preparation and model setup.
New Features:
Enhancements:
Documentation:
Chores: