diff --git a/.gitignore b/.gitignore index 563a769..84fa22f 100644 --- a/.gitignore +++ b/.gitignore @@ -226,4 +226,6 @@ pip-selfcheck.json *.xlsx *.xlsm -*.ods \ No newline at end of file +*.ods + +SheetGPT diff --git a/.vscode/settings.json b/.vscode/settings.json index 754aed4..215c7ea 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,6 @@ { "cSpell.words": [ + "flet", "kivy", "openai", "openpyxl", diff --git a/README.md b/README.md index 5a467ef..8fbb01a 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,39 @@ # SheetGPT + + A visual tool for completing a specific column/row using the data from other columns/rows in a XLSX/XLSM file utilizing ChatGPT. ## Status Still in development, but can be used. +## Usage + +1. Go to the releases page and download the latest executable suitable for your platform. +2. Launch the executable. +3. Choose an input file. +4. Choose an output file. +5. Choose the sheet you would like to use. +6. Enter the input column(s)/row(s), separated by a comma "," or a space " " character. (*e.g. `a,b,c` or `1 2 5`*) +7. Enter the input starting position. You must enter a row if you have entered column(s) in the previous step, and vice versa. +8. *(Optional) Enter the input ending position. It must be of the same kind (row or column) as the starting position, and must come after it. SheetGPT will stop once it reaches this position (inclusive).* +9. Enter the output column and row. +10. Choose the output placement strategy (place each subsequent result on the next row, or on the next column). +11. Enter your OpenAI API key. You can create one [here](https://platform.openai.com/account/api-keys). +12. Choose the ChatGPT model you would like to use. +13. Enter a system prompt to tailor the model's behavior to your needs. You can also use the default system prompt. +14. Enter a prompt to be used with each request. You can use `$0`, `$1`... placeholders (*zero-indexed, starting from 0*) to insert the inputs you have specified in step 6. +15.
*(Optional) Enter a processing limit, enter `0` to remove the limit. Only the items which were processed explicitly by ChatGPT will count towards the limit. Cached results, or existing results (if you leave* **Skip the existing results** *enabled) will not count towards the limit.* +16. Press **Start processing** to start. This might take a while depending on the data volume and device specs. + ## Requirements -- Python 3.7/3.8 +- Python 3.7+ - pip - venv +- Flet +- ezpyi (for creating AppImage) ## Development @@ -59,3 +82,35 @@ Still in development, but can be used. ```bash python3 main.py ``` + + Or to use hot-reload + + ```bash + flet run -d main.py + ``` + +### Packaging + +#### Executable + +Executable will be created for your platform (for Linux if you run it on Linux, for Windows if you run it on Windows...). + +```bash +flet pack main.py --icon icon.png --name SheetGPT --add-data "modules:modules" +``` + +#### AppImage + +An AppImage will be created to be used with Linux. + +```bash +ezpyi -A -i icon.png main.py SheetGPT +``` + +## [LICENSE](https://github.com/recoskyler/sheetGPT/blob/main/LICENSE) + +[MIT License](https://github.com/recoskyler/sheetGPT/blob/main/LICENSE) + +## About + +Made by [recoskyler](https://github.com/recoskyler) - 2023 diff --git a/modules/assistant.py b/assistant.py similarity index 80% rename from modules/assistant.py rename to assistant.py index 963c8e8..0a4f6ff 100644 --- a/modules/assistant.py +++ b/assistant.py @@ -7,7 +7,6 @@ def get_formatted_prompt(prompt: str, input_values: object) -> str: formatted_prompt = prompt - index = int(search("(\$)([0-9]+)", formatted_prompt).group()[1:]) while search("(\$)([0-9]+)", formatted_prompt) != None: match = search("(\$)([0-9]+)", formatted_prompt).group() @@ -24,13 +23,8 @@ def get_formatted_prompt(prompt: str, input_values: object) -> str: return formatted_prompt -def get_answer(prompt: str, system_prompt: str, model: str, api_key: str, cache: object) -> tuple: +def 
get_answer(prompt: str, system_prompt: str, model: str, api_key: str) -> tuple: openai.api_key = api_key - prompt_hash = str(hash(prompt)) - - if prompt_hash in cache.keys(): - print("Using cached answer...") - return (cache[prompt_hash], 0) try: print("Getting answer from ChatGPT...") @@ -41,6 +35,4 @@ def get_answer(prompt: str, system_prompt: str, model: str, api_key: str, cache: print(e) return ("", 0) - cache[prompt_hash] = res - return (res, 1) diff --git a/modules/cell.py b/cell.py similarity index 97% rename from modules/cell.py rename to cell.py index 006cb92..4daf564 100644 --- a/modules/cell.py +++ b/cell.py @@ -40,7 +40,7 @@ def get_numeric_value(coordinate: str) -> int: elif coordinate.isnumeric(): return int(coordinate) else: - raise Exception("Invalid coordinate. Must be numeric (1), or alphabetic (A). Not both (A1)") + raise Exception("Invalid coordinate. Must be numeric (1), or alphabetic (A). Not both (A1): ", coordinate) def get_cell_value(worksheet, cell: str): if worksheet == None: diff --git a/icon.png b/icon.png new file mode 100644 index 0000000..f812848 Binary files /dev/null and b/icon.png differ diff --git a/main.py b/main.py index aa86bcb..c3a1c46 100644 --- a/main.py +++ b/main.py @@ -1,11 +1,14 @@ #!/usr/bin/python from os import name, system -from modules.spreadsheet import create_result_book, load_workbook -from modules.cell import set_cell_value, get_cell_value, get_inputs, get_next_column, get_numeric_value -from modules.assistant import get_formatted_prompt, get_answer -from modules.io import save_progress, save_and_close -from modules.user_input import * +from os.path import exists, isfile +from spreadsheet import create_output_book, load_input_workbook, generate_result_path +from cell import set_cell_value, get_cell_value, get_inputs, get_next_column, get_numeric_value +from assistant import get_formatted_prompt, get_answer +from save import save_progress, save_and_close +from re import split, search + +import flet as ft if 
__name__ != '__main__': print("Not supported as a module") @@ -14,17 +17,23 @@ # Constants SAVE_INTERVAL = 25 # Save after every 25 processed items +MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"] +DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. You give only the answer, without forming a sentence. If you are not sure, try guessing. If you are still unsure about the answer, output '?'. If you don't know the answer, or if you cannot give a correct answer output '?'." +INPUTS_REGEX = "(^([0-9]+)([, ]([0-9]+))*$)|(^([a-zA-Z]+)([, ]([a-zA-Z]+))*$)" # Variables +exit_app = False +processing = False + file_path = "" result_path = "" api_key = "" -model = "" +model = MODELS[1] sheet_name = "Sheet 1" -system_prompt = "" +system_prompt = DEFAULT_SYSTEM_PROMPT prompt = "" -result_row = 1 +result_row = "" result_col = "" inputs = "" workbook = None @@ -38,160 +47,1105 @@ processed = 0 until = "" skip = True +overwrite_output = False +process_task = None +ended = False +valid_fields = { + "input_file": False, + "output_file": False, + "inputs": False, + "input_pos": False, + "output_row": False, + "output_col": False, + "api_key": False, + "system_prompt": True, + "until": True, + "prompt": False, + "limit": True, + "sheet": False +} # Functions -def get_user_inputs(): +def reset_globals(): global file_path + global result_path global api_key global model global sheet_name global system_prompt global prompt - global result_path - global result_book - global result_sheet - global result_placement global result_row global result_col global inputs - global worksheet global workbook + global worksheet + global result_book + global result_sheet + global result_placement global input_pos + global cache global limit - global skip + global processed global until + global skip + global overwrite_output + global process_task + global ended + global valid_fields + global MODELS + global DEFAULT_SYSTEM_PROMPT - file_paths = ask_file_paths() + file_path = "" + 
result_path = "" + api_key = "" + model = MODELS[1] + sheet_name = "Sheet 1" + system_prompt = DEFAULT_SYSTEM_PROMPT + prompt = "" + result_row = "" + result_col = "" + inputs = "" + workbook = None + worksheet = None + result_book = None + result_sheet = None + result_placement = None + input_pos = "" + cache = dict() + limit = 0 + processed = 0 + until = "" + skip = True + overwrite_output = False + process_task = None + ended = False + valid_fields = { + "input_file": False, + "output_file": False, + "inputs": False, + "input_pos": False, + "output_row": False, + "output_col": False, + "api_key": False, + "system_prompt": True, + "until": True, + "prompt": False, + "limit": True, + "sheet": False + } - file_path = file_paths[0] - result_path = file_paths[1] - workbook = file_paths[2] - result_book = file_paths[3] +def clear(): + if name == 'nt': + _ = system('cls') + else: + _ = system('clear') - print("\nInput file: " + file_path) - print("\nOutput file: " + result_path) +# GUI - sheets = ask_worksheet(workbook, result_book) +def main(page: ft.Page): + global MODELS + global DEFAULT_SYSTEM_PROMPT + global model + global skip - worksheet = sheets[0] - result_sheet = sheets[1] + def process_item(current_pos, current_col, current_row, input_values): + global worksheet + global result_sheet + global prompt + global result_sheet + global result_placement + global inputs + global skip + global until + global ended + global processed + global exit_app - api_key = ask_api_key() - model = ask_model() - system_prompt = ask_system_prompt() - prompt = ask_prompt() - inputs = ask_inputs() - input_pos = ask_input_start(inputs) - result_col = ask_output_column() - result_row = ask_output_row() + if ended or exit_app: return False - print("\n\nStarting point for results: ") - print(result_col + str(result_row)) - print("\n") + input_hash = str(hash(frozenset(input_values))) + result_value = get_cell_value(worksheet=result_sheet, cell=(current_col + current_row)) - 
result_placement = ask_output_placement() - limit = ask_limit() - skip = ask_skip() - until = ask_until(input_pos) + if result_value == None or str(result_value).strip() == "" or not skip: + answer = "" + processed_count = 0 -def process_worksheet(): - global worksheet - global result_sheet - global input_pos - global prompt - global result_book - global result_sheet - global result_placement - global result_row - global result_col - global inputs - global limit - global processed - global skip - global until + if input_hash in cache.keys(): + print("Using cached answer...") + answer = cache[input_hash] + else: + current_prompt = get_formatted_prompt(prompt, input_values) - item_no = 1 - processed = 0 + print("Prompt: " + current_prompt) - print("Processing... Press CTRL + C to cancel") + res = get_answer(current_prompt, system_prompt, model, api_key) + answer = res[0] + processed_count = res[1] - if limit == 0: - print("\nWARNING: No limit set!\n") - else: - print("\nLimiting to " + str(limit) + " items...\n") + set_cell_value(worksheet=result_sheet, cell=(current_col + current_row), value=answer) - while limit == 0 or processed <= limit: - if processed != 0 and processed % SAVE_INTERVAL == 0: - save_progress(result_book, result_path) + print("Answer: " + answer) - if until != "" and get_numeric_value(input_pos) > get_numeric_value(until): break + processed = processed + processed_count + else: + print("Already has result. 
Skipping...") - print("Processing item " + str(processed) + "/" + str(item_no) + " | Position: " + input_pos + "...") + if input_hash not in cache.keys(): + print("Caching answer...") + cache[input_hash] = str(result_value).strip() - input_values = get_inputs(worksheet, inputs, input_pos) + return True - if len(input_values) == 0: break + def process_worksheet(): + global worksheet + global result_sheet + global input_pos + global prompt + global result_book + global result_sheet + global result_placement + global result_row + global result_col + global inputs + global limit + global skip + global until + global processed + global processing + global exit_app - current_prompt = get_formatted_prompt(prompt, input_values) - result_value = get_cell_value(worksheet=result_sheet, cell=(result_col + result_row)) + if processing or exit_app: return - if result_value == None or str(result_value).strip() == "" or not skip: - print("Prompt: " + current_prompt) + item_no = 1 + processed = 0 + current_pos = input_pos + current_col = result_col + current_row = result_row + processing = True - res = get_answer(current_prompt, system_prompt, model, api_key, cache) - answer = res[0] - processed_count = res[1] + print("Processing... 
Press CTRL + C to cancel") - set_cell_value(worksheet=result_sheet, cell=(result_col + result_row), value=answer) + if limit == 0: + print("\nWARNING: No limit set!\n") + else: + print("\nLimiting to " + str(limit) + " items...\n") - print("Answer: " + answer) + while limit == 0 or processed <= limit: + if processed != 0 and processed % SAVE_INTERVAL == 0: + save_progress(result_book, result_path) - processed = processed + processed_count + if until != "" and get_numeric_value(current_pos) > get_numeric_value(until): break + + input_values = get_inputs(worksheet, inputs, current_pos) + + if len(input_values) == 0: break + + print("Processing item " + str(processed) + "/" + str(item_no) + " | Position: " + current_pos + "...") + + process_message.value = "Processing item " + str(processed) + "/" + str(item_no) + " | Position: " + current_pos + "..." + page.update() + + if not process_item(current_pos, current_col, current_row, input_values): break + + if result_placement == "r": # Place next result on the next row + current_row = str(int(current_row) + 1) + else: # Place next result on the next column + current_col = get_next_column(current_col) + + if current_pos.isnumeric(): + current_pos = str(int(current_pos) + 1) + else: + current_pos = get_next_column(current_pos) + + item_no = item_no + 1 + + processing = False + + print("Done!") + + def stop_processing(e): + global ended + + ended = True + process_message.value = "Stopping..." 
+ + page.update() + + def handle_window_event(e): + global exit_app + global processing + + if e.data == "close": + if not processing: + page.window_destroy() + return + + exit_app = True + stop_processing("") + + page.title = "SheetGPT" + page.scroll = True + page.window_width = 790 + page.window_height = 600 + page.window_max_width = 790 + page.window_min_width = 510 + page.window_min_height = 300 + page.window_prevent_close = True + page.on_window_event = handle_window_event + + input_text = ft.Text( + "No input file selected", + overflow=ft.TextOverflow.ELLIPSIS, + max_lines=1, + col={"md": 9} + ) + + output_text = ft.Text( + "No output file selected", + overflow=ft.TextOverflow.ELLIPSIS, + max_lines=1, + col={"md": 9} + ) + + input_error = ft.Text("", color=ft.colors.RED, visible=False) + output_error = ft.Text("", color=ft.colors.RED, visible=False) + + input_progress = ft.Row( + [ft.ProgressRing(width=16, height=16)], + col={"md": 9}, + visible=False + ) + + output_progress = ft.Row( + [ft.ProgressRing(width=16, height=16)], + col={"md": 9}, + visible=False + ) + + process_progress = ft.ProgressRing(visible=False) + + def check_validity(): + global valid_fields + global result_book + global result_sheet + global workbook + global worksheet + global sheet_name + global result_path + global file_path + + valid = True + + # print("Checking validity ======") + + for item in valid_fields: + if not valid_fields[item]: + valid = False + # print("Not valid: ", item) + + if result_book == None or workbook == None or worksheet == None or result_sheet == None or sheet_name == "" or file_path == "" or result_path == "": + valid = False + + process_button.disabled = not valid + + page.update() + + def open_url(e): + page.launch_url(e.data) + + def on_choose_file(e: ft.FilePickerResultEvent): + global file_path + global workbook + global worksheet + global result_path + global result_book + global result_sheet + global valid_fields + global sheet_name + + if e.files == 
None or len(e.files) != 1: + input_progress.visible = False + input_text.visible = True + + page.update() + check_validity() + + return + + workbook = None + worksheet = None + result_book = None + result_sheet = None + sheet_name = "" + file_path = e.files[0].path + result_path = "" + valid_fields["output_file"] = False + valid_fields["input_file"] = False + + input_error.visible = False + input_text.visible = False + input_text.value = "No input file selected" + + output_button.disabled = True + output_error.visible = False + output_text.visible = False + output_text.value = "No output file selected" + + sheet_dropdown.value = None + sheet_dropdown.options = [] + sheet_dropdown.disabled = True + + try: + workbook = load_input_workbook(file_path) + input_text.value = file_path + + options = [] + + for item in workbook.sheetnames: + options.append(ft.dropdown.Option(item)) + + sheet_dropdown.options = options + sheet_dropdown.disabled = False + + if len(options) > 0: + sheet_name = workbook.sheetnames[0] + sheet_dropdown.value = sheet_name + worksheet = workbook[sheet_name] + + valid_fields["input_file"] = True + valid_fields["sheet"] = True + else: + sheet_dropdown.error_text = "No sheets found. 
Please use another file" + valid_fields["sheet"] = False + except Exception as e: + input_error.visible = True + input_error.value = str(e) + + input_text.visible = True + input_progress.visible = False + output_button.disabled = False + output_text.visible = True + + page.update() + check_validity() + + def on_choose_save(e: ft.FilePickerResultEvent): + global result_path + global result_book + global valid_fields + global sheet_name + global result_sheet + global file_path + + if e.path == None: + output_progress.visible = False + output_text.visible = True + + page.update() + check_validity() + + return + + result_book = None + result_sheet = None + result_path = e.path + + output_error.visible = False + output_text.visible = False + output_text.value = "No output file selected" + valid_fields["output_file"] = False + + try: + result_book = create_output_book(file_path, result_path, True) + output_text.value = result_path + result_sheet = result_book[sheet_name] + valid_fields["output_file"] = True + except Exception as e: + output_error.visible = True + output_error.value = str(e) + + output_text.visible = True + output_progress.visible = False + + page.update() + check_validity() + + file_picker = ft.FilePicker(on_result=on_choose_file) + save_picker = ft.FilePicker(on_result=on_choose_save) + + page.overlay.append(file_picker) + page.overlay.append(save_picker) + + def on_input_button_clicked(e): + input_text.visible = False + input_progress.visible = True + + page.update() + check_validity() + + file_picker.pick_files( + allow_multiple=False, + dialog_title="Choose input file", + allowed_extensions=["xlsx", "xlsm", "xltx", "xltm"], + file_type=ft.FilePickerFileType.CUSTOM + ) + + input_button = ft.FilledButton( + "Choose input file", + on_click=on_input_button_clicked, + col={"md": 3} + ) + + def on_output_button_clicked(e): + output_text.visible = False + output_progress.visible = True + + page.update() + check_validity() + + save_picker.save_file( + 
dialog_title="Choose output file", + allowed_extensions=["xlsx", "xlsm", "xltx", "xltm"], + file_type=ft.FilePickerFileType.CUSTOM, + file_name=generate_result_path(file_path) + ), + + output_button = ft.OutlinedButton( + "Choose output file", + on_click=on_output_button_clicked, + disabled=True, + col={"md": 3} + ) + + def sheet_changed(e): + global sheet_name + global worksheet + global workbook + global result_book + global result_sheet + + sheet_name = sheet_dropdown.value + worksheet = workbook[sheet_name] + result_sheet = result_book[sheet_name] + + sheet_dropdown = ft.Dropdown( + label="Sheet *", + on_change=sheet_changed, + disabled=True, + col={"md": 4} + ) + + def inputs_changed(e): + global inputs + global until + global input_pos + global INPUTS_REGEX + global valid_fields + + inputs = [] + input_pos = "" + until = "" + input_pos_field.disabled = True + until_field.disabled = True + input_pos_field.value = "" + until_field.value = "" + valid_fields["inputs"] = False + + if e.control.value != "" and search(INPUTS_REGEX, e.control.value) != None: + inputs_field.error_text = "" + inputs = split(",| ", e.control.value.upper()) + input_pos_field.disabled = False + valid_fields["inputs"] = True else: - print("Already has result. Skipping...") + inputs_field.error_text = "You must choose either row(s) (i.e. '2' or '4,2,0,6,9'), or column(s) (i.e. 'd' or 'p,r,n,d'). Not both (i.e. '1,a,c,4')." 
- prompt_hash = str(hash(current_prompt)) + page.update() + check_validity() - if prompt_hash not in cache.keys(): - print("Caching answer...") - cache[prompt_hash] = str(result_value).strip() + inputs_field = ft.TextField( + label="Input columns or rows (comma ',' or space ' ' separated) *", + on_change=inputs_changed, + max_length=500, + multiline=False, + col={"md": 8} + ) + + def input_pos_changed(e): + global input_pos + global inputs + global until + global valid_fields - if result_placement == "r": # Place next result on the next row - result_row = str(int(result_row) + 1) - else: # Place next result on the next column - result_col = get_next_column(result_col) + input_pos = "" + input_pos_field.error_text = "" + until_field.error_text = "" + until_field.disabled = True + valid_fields["input_pos"] = False - if input_pos.isnumeric(): - input_pos = str(int(input_pos) + 1) + if inputs[0].isnumeric(): + input_pos_field.error_text = "Please enter a valid column (must be a character)" else: - input_pos = get_next_column(input_pos) + input_pos_field.error_text = "Please enter a valid row (must be a number)" - item_no = item_no + 1 + if inputs[0].isnumeric() and e.control.value.strip().isalpha(): + input_pos = e.control.value.strip().upper() + until_field.disabled = False + input_pos_field.error_text = "" + valid_fields["input_pos"] = True - print("Done!") + if inputs[0].isalpha() and e.control.value.strip().isnumeric(): + input_pos = e.control.value.strip() + until_field.disabled = False + input_pos_field.error_text = "" + valid_fields["input_pos"] = True -def clear(): - if name == 'nt': - _ = system('cls') - else: - _ = system('clear') + if input_pos != "" and until != "" and input_pos.isalpha() and until.isnumeric(): + until_field.error_text = "Please enter a valid column (must be a character)" + valid_fields["until"] = False + elif input_pos != "" and until != "" and input_pos.isnumeric() and until.isalpha(): + until_field.error_text = "Please enter a valid row 
(must be a number)" + valid_fields["until"] = False + + if input_pos.isnumeric() and until.isnumeric() and int(input_pos) >= int(until): + until_field.error_text = "Ending row cannot be equal or smaller than the starting row" + valid_fields["until"] = False + + if input_pos.isalpha() and until.isalpha() and get_numeric_value(input_pos) >= get_numeric_value(until): + until_field.error_text = "Ending column cannot be equal or before the starting column" + valid_fields["until"] = False + + page.update() + check_validity() + + input_pos_field = ft.TextField( + label="Input starting position *", + on_change=input_pos_changed, + max_length=10, + col={"md": 6}, + multiline=False, + disabled=True + ) + + def until_changed(e): + global input_pos + global inputs + global until + global valid_fields + + until = "" + valid_fields["until"] = False + + if input_pos.isalpha(): + until_field.error_text = "Please enter a valid column (must be a character)" + else: + until_field.error_text = "Please enter a valid row (must be a number)" + + if e.control.value.strip() == "": + valid_fields["until"] = True + until_field.error_text = "" + + if input_pos.isalpha() and e.control.value.strip().isalpha(): + until = e.control.value.strip().upper() + until_field.error_text = "" + valid_fields["until"] = True + + if input_pos.isnumeric() and e.control.value.strip().isnumeric(): + until = e.control.value.strip().upper() + until_field.error_text = "" + valid_fields["until"] = True + + if until != "" and input_pos.isnumeric() and until.isnumeric() and get_numeric_value(input_pos) >= get_numeric_value(until): + until = "" + until_field.error_text = "Ending row cannot be equal or smaller than the starting row" + + if until != "" and input_pos.isalpha() and until.isalpha() and get_numeric_value(input_pos) >= get_numeric_value(until): + until = "" + until_field.error_text = "Ending column cannot be equal or before the starting column" + + page.update() + check_validity() + + until_field = 
ft.TextField( + label="Input ending", + on_change=until_changed, + max_length=10, + col={"md": 6}, + multiline=False, + disabled=True + ) + + def output_row_changed(e): + global result_row + global valid_fields + + result_row = e.control.value.strip() + output_row_field.error_text = "" + valid_fields["output_row"] = True + + if result_row == "" or not result_row.isnumeric(): + output_row_field.error_text = "Please enter a valid row" + result_row = "" + valid_fields["output_row"] = False + + page.update() + check_validity() + + output_row_field = ft.TextField( + label="Output row *", + on_change=output_row_changed, + max_length=10, + col={"md": 4}, + multiline=False + ) + + def output_col_changed(e): + global result_col + global valid_fields + + result_col = e.control.value.strip() + output_col_field.error_text = "" + valid_fields["output_col"] = True + + if result_col == "" or not result_col.isalpha(): + output_col_field.error_text = "Please enter a valid column" + result_col = "" + valid_fields["output_col"] = False + + page.update() + check_validity() + + output_col_field = ft.TextField( + label="Output column *", + on_change=output_col_changed, + max_length=10, + col={"md": 4}, + multiline=False + ) + + def output_placement_changed(e): + global result_placement + + if output_placement_dropdown.value == "Place on the next row": + result_placement = "r" + elif output_placement_dropdown.value == "Place on the next column": + result_placement = "c" + + output_placement_dropdown = ft.Dropdown( + label="Output placement *", + on_change=output_placement_changed, + options=[ + ft.dropdown.Option("Place on the next row"), + ft.dropdown.Option("Place on the next column") + ], + col={"md": 4}, + value="Place on the next row" + ) + + def api_key_changed(e): + global api_key + global valid_fields + + api_key = e.control.value.strip() + valid_fields["api_key"] = False + api_key_field.error_text = "" + + if api_key == "": + api_key_field.error_text = "Please enter a valid API 
key" + else: + valid_fields["api_key"] = True + + page.update() + check_validity() + + api_key_field = ft.TextField( + label="OpenAI API Key *", + on_change=api_key_changed, + expand=False, + max_length=64, + col={"md": 8}, + multiline=False + ) + + def model_changed(e): + global model + model = model_dropdown.value + + dropdown_options = [] + + for item in MODELS: + dropdown_options.append(ft.dropdown.Option(item)) + + model_dropdown = ft.Dropdown( + label="ChatGPT Model *", + on_change=model_changed, + options=dropdown_options, + col={"md": 4}, + value=model + ) + + api_key_info = ft.Markdown( + "You can get your OpenAI API key [here](https://platform.openai.com/account/api-keys)", + on_tap_link=open_url + ) + + def system_prompt_changed(e): + global system_prompt + global valid_fields + + system_prompt = e.control.value.strip() + valid_fields["system_prompt"] = False + system_prompt_field.error_text = "" + + if system_prompt == "": + system_prompt_field.error_text = "Please enter a valid system prompt" + else: + valid_fields["system_prompt"] = True + + page.update() + check_validity() + + system_prompt_field = ft.TextField( + label="System prompt *", + on_change=system_prompt_changed, + max_length=1024, + multiline=True, + value=DEFAULT_SYSTEM_PROMPT + ) + + def prompt_changed(e): + global prompt + global valid_fields + + prompt = e.control.value.strip() + valid_fields["prompt"] = False + prompt_field.error_text = "" + + if prompt == "": + prompt_field.error_text = "Please enter a valid prompt" + else: + valid_fields["prompt"] = True + + page.update() + check_validity() + + prompt_field = ft.TextField( + label="Prompt *", + on_change=prompt_changed, + max_length=1024, + multiline=True + ) + + prompt_info_md = "Enter a prompt to be used with each request. 
Use zero-based `$0` placeholder(s) to insert inputs.\n### Example\nAssuming you have entered `a,b` as the inputs, and `2` as the starting row:\n\n_\"Bla bla $0 and $0, $1\"_\n\nwill be converted to\n\n_\"Bla bla and , \"_" + + prompt_info = ft.Markdown(prompt_info_md) + + def limit_changed(e): + global limit + global valid_fields + + limit = 0 + limit_field.error_text = "" + valid_fields["limit"] = False + + if e.control.value.strip() == "" or not e.control.value.strip().isnumeric() or int(e.control.value.strip()) < 0: + limit_field.error_text = "Please enter a valid limit" + else: + limit = int(e.control.value.strip()) + valid_fields["limit"] = True + + page.update() + check_validity() + + limit_field = ft.TextField( + label="Processing limit *", + on_change=limit_changed, + max_length=10, + col={"md": 3}, + multiline=False, + value="0" + ) + + def on_skip_changed(e): + global skip + skip = skip_switch.value + + skip_switch = ft.Switch( + label="Skip the existing results", + value=skip, + on_change=on_skip_changed + ) + + def start_processing(e): + global file_path + global result_path + global api_key + global model + global sheet_name + global system_prompt + global prompt + global result_row + global result_col + global inputs + global workbook + global worksheet + global result_book + global result_sheet + global result_placement + global input_pos + global cache + global limit + global processed + global until + global skip + global overwrite_output + global process_task + global ended + global valid_fields + global MODELS + global DEFAULT_SYSTEM_PROMPT + global exit_app + + process_error.value = "" + process_message.value = "Processing..." 
+ ended = False + process_spinner.visible = True + stop_button.disabled = False + stop_button.visible = True + process_button.disabled = True + process_button.visible = False + skip_switch.disabled = True + limit_field.disabled = True + api_key_field.disabled = True + prompt_field.disabled = True + system_prompt_field.disabled = True + until_field.disabled = True + inputs_field.disabled = True + input_pos_field.disabled = True + output_col_field.disabled = True + output_row_field.disabled = True + output_placement_dropdown.disabled = True + model_dropdown.disabled = True + sheet_dropdown.disabled = True + input_button.disabled = True + output_button.disabled = True + + page.update() + + try: + process_worksheet() + except Exception as e: + print("Failed processing") + print(e) + process_error.value = "An error occurred" + + process_message.value = "Saving..." + + page.update() + + process_message.value = "Failed" + + try: + save_and_close(workbook, result_book, result_path) + process_message.value = "Saved to " + result_path + except Exception as e: + print("Failed saving") + print(e) + process_error.value = "An error occurred while saving" + + process_spinner.visible = False + stop_button.disabled = True + stop_button.visible = False + process_button.disabled = False + process_button.visible = True + skip_switch.disabled = False + limit_field.disabled = False + api_key_field.disabled = False + prompt_field.disabled = False + system_prompt_field.disabled = False + until_field.disabled = True + inputs_field.disabled = False + input_pos_field.disabled = True + output_col_field.disabled = False + output_row_field.disabled = False + output_placement_dropdown.disabled = False + model_dropdown.disabled = False + sheet_dropdown.disabled = True + input_button.disabled = False + output_button.disabled = True + + reset_globals() + + limit_field.value = limit + api_key_field.value = api_key + prompt_field.value = prompt + system_prompt_field.value = DEFAULT_SYSTEM_PROMPT + 
until_field.value = until + inputs_field.value = ",".join(inputs) + input_pos_field.value = input_pos + output_col_field.value = result_col + output_row_field.value = result_row + output_placement_dropdown.value = "Place on the next row" + model_dropdown.value = model + sheet_dropdown.value = None + input_text.value = "No input file selected" + output_text.value = "No output file selected" + skip_switch.value = skip + + page.update() + check_validity() + + if exit_app: page.window_destroy() + + process_spinner = ft.Row( + [ft.ProgressRing(width=16, height=16)], + col={"md": 1}, + alignment=ft.MainAxisAlignment.END, + visible=False + ) + + process_message = ft.Text( + "(*) Required fields", + overflow=ft.TextOverflow.ELLIPSIS, + max_lines=1 + ) + + process_button = ft.ElevatedButton( + "Start processing", + disabled=True, + on_click=start_processing, + col={"md": 3} + ) + + stop_button = ft.OutlinedButton( + "Stop processing", + disabled=True, + visible=False, + on_click=stop_processing, + col={"md": 3} + ) -# Program + process_error = ft.Text( + "", + color=ft.colors.RED + ) -try: - clear() + root_column = ft.Column( + [ + ft.Column( + controls=[ + ft.Text( + "SheetGPT", + style=ft.TextThemeStyle.HEADLINE_MEDIUM, + text_align=ft.TextAlign.CENTER + ), + ], + horizontal_alignment=ft.CrossAxisAlignment.STRETCH, + ), + ft.Text("File configuration", style=ft.TextThemeStyle.TITLE_LARGE), + ft.Text("Input file", style=ft.TextThemeStyle.TITLE_MEDIUM), + ft.ResponsiveRow( + [ + input_button, + input_progress, + input_text + ], + spacing=10, + vertical_alignment=ft.CrossAxisAlignment.CENTER + ), + input_error, + ft.Text("Output file", style=ft.TextThemeStyle.TITLE_MEDIUM), + ft.ResponsiveRow( + [ + output_button, + output_progress, + output_text + ], + spacing=10, + vertical_alignment=ft.CrossAxisAlignment.CENTER + ), + output_error, + ft.Divider(), + ft.Text("Input configuration", style=ft.TextThemeStyle.TITLE_LARGE), + ft.ResponsiveRow( + [ + sheet_dropdown, + inputs_field 
+ ], + vertical_alignment=ft.CrossAxisAlignment.START, + spacing=10 + ), + ft.ResponsiveRow( + [ + input_pos_field, + until_field + ], + vertical_alignment=ft.CrossAxisAlignment.START, + spacing=10 + ), + ft.Divider(), + ft.Text("Output configuration", style=ft.TextThemeStyle.TITLE_LARGE), + ft.ResponsiveRow( + [ + output_col_field, + output_row_field, + output_placement_dropdown + ], + vertical_alignment=ft.CrossAxisAlignment.START, + spacing=10 + ), + ft.Divider(), + ft.Text("ChatGPT configuration", style=ft.TextThemeStyle.TITLE_LARGE), + api_key_info, + ft.ResponsiveRow( + [ + model_dropdown, + api_key_field + ], + vertical_alignment=ft.CrossAxisAlignment.START, + spacing=10 + ), + system_prompt_field, + ft.Divider(), + ft.Text("Prompt configuration", style=ft.TextThemeStyle.TITLE_LARGE), + prompt_info, + prompt_field, + ft.Divider(), + ft.Text("Options", style=ft.TextThemeStyle.TITLE_LARGE), + skip_switch, + ft.Text("Results will be overwritten if skipping is disabled"), + ft.Divider(), + ft.Text("Processing limit", style=ft.TextThemeStyle.TITLE_MEDIUM), + ft.Text("Enter a limit for number of items to be processed by ChatGPT. Cached or existing results do not count towards the limit. (type 0 to remove limit)"), + limit_field, + ft.Divider(), + ft.ResponsiveRow( + [ + stop_button, + process_button, + ft.Row( + [process_message], + col={"md": 8}, + vertical_alignment=ft.CrossAxisAlignment.CENTER + ), + process_spinner + ], + spacing=10, + vertical_alignment=ft.CrossAxisAlignment.CENTER + ), + process_error + ], + spacing=15, + run_spacing=15 + ) - print("\n\nSheetGPT v0.1.0 by recoskyler\n\n") + root_container = ft.SafeArea( + content=ft.Container( + padding=5, + content=root_column + ), + ) - get_user_inputs() - process_worksheet() -except KeyboardInterrupt: - print("\n\nCtrl+C detected. 
Exiting...\n\n") -except Exception as e: - print("\n\nAn error occurred\n\n") - print(e) -finally: - save_and_close(workbook, result_book, result_path) + page.add(root_container) - print("\n======= Adios, Cowboy =======\n") +ft.app(target=main) diff --git a/modules/__init__.py b/modules/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/modules/io.py b/modules/io.py deleted file mode 100644 index 2018cca..0000000 --- a/modules/io.py +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/python - -__all__ = ["save_progress", "save_and_close"] - -def save_progress(workbook, path: str): - if workbook == None: - raise TypeError("workbook must not be NoneType") - - try: - print("\nSaving file...") - - workbook.save(path) - - print("\nResults saved to: " + path) - except Exception as e: - print("\nAn error occurred while saving\n\n") - print(e) - -def save_and_close(workbook, result_book, result_path: str): - try: - if workbook != None: - workbook.close() - except Exception as e: - print("\nFailed to close workbook\n") - print(e) - - - try: - if result_book != None: - print("\nSaving file...") - - result_book.save(result_path) - result_book.close() - - print("\nResults saved to: " + result_path) - else: - print("\n\nNot saving...") - except Exception as e: - print("\nFailed to save and close result book\n") - print(e) diff --git a/modules/spreadsheet.py b/modules/spreadsheet.py deleted file mode 100644 index 99917ba..0000000 --- a/modules/spreadsheet.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/python - -from openpyxl import load_workbook -from os.path import isfile, exists -from os import remove -from shutil import copyfile - -__all__ = ["load_input_workbook", "create_result_book"] - -def load_input_workbook(file_path: str): - wb = None - - if not exists(file_path) or not isfile(file_path): - print("File does not exist or invalid\n") - return False - - try: - print("\nLoading workbook...") - - wb = load_workbook(file_path, read_only=True) - - print("Loaded 
workbook\n") - except Exception as e: - print(e) - print("Failed to load workbook. Is the file format correct (.xlsx or .xlsm)?\n") - return False - - return wb - -def create_result_book(file_path: str, result_path: str): - overwrite = False - result_book = None - - if exists(result_path) and isfile(result_path): - while True: - res = input("Output worksheet already exists. Would you like to delete the old worksheet and create it again? (y/n) ").strip().lower() - - if res == "y": - overwrite = True - - print("Deleting old output file...\n") - - try: - remove(result_path) - except: - print("Failed to delete old output worksheet.\n") - return False - - break - elif res == "n": - print("Using the old file...") - break - - print("\nPlease enter a valid choice.\n") - - try: - if not exists(result_path) or not isfile(result_path) or overwrite: - print("Creating output worksheet...") - - copyfile(file_path, result_path) - - result_book = load_workbook(result_path) - - print("Loaded output worksheet: " + result_path) - except Exception as e: - print("\nFailed to create output worksheet.\n") - print(e) - - if exists(result_path) and isfile(result_path): - print("Removing invalid output file...") - - try: - remove(result_path) - except: - print("Failed to remove invalid output file") - - return False - - return result_book \ No newline at end of file diff --git a/modules/user_input.py b/modules/user_input.py deleted file mode 100644 index 3a1eb79..0000000 --- a/modules/user_input.py +++ /dev/null @@ -1,310 +0,0 @@ -#!/usr/bin/python - -from modules.spreadsheet import create_result_book, load_input_workbook -from modules.cell import get_numeric_value -from os.path import splitext -from re import split, search - -__all__ = [ - "ask_file_paths", - "ask_worksheet", - "ask_api_key", - "ask_model", - "ask_system_prompt", - "ask_prompt", - "ask_inputs", - "ask_input_start", - "ask_output_row", - "ask_output_column", - "ask_skip", - "ask_output_placement", - "ask_until", - 
"ask_limit" -] - -# Constants - -DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. You give only the answer, without forming a sentence. If you are not sure, try guessing. If you are still unsure about the answer, output '?'. If you don't know the answer, or if you cannot give a correct answer output '?'." - -MODELS = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"] - -INPUTS_REGEX = "(^([0-9]+)([, ]([0-9]+))*$)|(^([a-zA-Z]+)([, ]([a-zA-Z]+))*$)" - -# Functions - -def ask_file_paths() -> tuple: - file_path = "" - result_path = "" - workbook = None - result_book = None - - while True: - file_path = input("\nEnter the full or relative file path (.xlsx/.xlsm): ").strip() - - workbook = load_input_workbook(file_path) - - if workbook != None and workbook != False: - break - - while True: - result_path = input("\nEnter the full or relative output file path (.xlsx/.xlsm) (press ENTER to use default): ").strip() - - if result_path == "": - pre, ext = splitext(file_path) - result_path = pre + "_output" + ext - - result_book = create_result_book(file_path, result_path) - - if result_book != None and result_book != False: - break - - return (file_path, result_path, workbook, result_book) - -def ask_worksheet(workbook, result_book) -> tuple: - if workbook == None: - raise TypeError("workbook must not be NoneType") - - if result_book == None: - raise TypeError("result_book must not be NoneType") - - if len(workbook.sheetnames) == 0: - print("\nNo sheets found. 
Exiting...\n") - exit(0) - - while True: - print("\nSheets in the worksheet:\n") - - index = 1 - - for sheet in workbook.sheetnames: - print(str(index) + ") " + sheet) - index = index + 1 - - choice = input("\nChoose sheet (press ENTER to select the first sheet): ").strip() - - if choice == "": - sheet_name = workbook.sheetnames[0] - - break - elif choice.isnumeric() and int(choice) >= 1 and int(choice) <= len(workbook.sheetnames): - sheet_name = workbook.sheetnames[int(choice) - 1] - - print("\nUsing the sheet with name: ") - print(sheet_name) - print("\n") - - break - - print("\nPlease choose a valid sheet") - - return(workbook[sheet_name], result_book[sheet_name]) - -def ask_api_key() -> str: - api_key = "" - - while True: - api_key = input("\nEnter your OpenAI API key: ").strip() - - if api_key != "": break - - print("\nPlease enter a valid API key\n") - - return api_key - -def ask_model() -> str: - global MODELS - - model = "" - - while True: - choice = input("\nChoose a ChatGPT model:\n\n1) gpt-3.5-turbo\n2) gpt-3.5-turbo-16k (default)\n3) gpt-4\n4) gpt-4-32k\n\nChoice (1-4, ENTER for default): ").strip() - - if choice == "": - model = MODELS[1] - print("\nUsing the default model\n") - break - - if choice.isnumeric() and int(choice) > 0 and int(choice) < 5: - model = MODELS[int(choice) - 1] - print("\nUsing the model: ") - print(model) - break - - print("\nPlease choose a valid model\n") - - return model - -def ask_system_prompt() -> str: - global DEFAULT_SYSTEM_PROMPT - - system_prompt = input("\nEnter a system prompt to fine-tune the response (press ENTER for default):\n\n").strip() - - if system_prompt == "": - system_prompt = DEFAULT_SYSTEM_PROMPT - - print("\nUsing the default prompt:\n") - print(system_prompt) - print("\n") - - return system_prompt - -def ask_prompt() -> str: - prompt = "" - - while True: - prompt = input("\nEnter a prompt to be used with each request. Use zero-based '$0' placeholder(s) to insert inputs (i.e. 
Assuming you have entered 'a,b' as the inputs, and '2' as the starting row: 'Something $0 and $0, $1' will be converted to 'Something and , ')\n\n").strip() - - if prompt != "": - break - - print("\nPlease enter a valid prompt\n") - - return prompt - -def ask_inputs() -> object: - global INPUTS_REGEX - - inputs = [] - - while True: - inputs = input("\nEnter one or more input columns (A-Z)/rows (1-9) separated by a space ('1 3') or a comma ('1,3') character:\n\n").strip().strip(",") - - if inputs != "" and search(INPUTS_REGEX, inputs) != None: - inputs = split(",| ", inputs.upper()) - break - - print("\nPlease enter a valid input. You must choose either row(s) (i.e. '2' or '4,2,0,6,9'), or column(s) (i.e. 'd' or 'p,r,n,d'). Not both (i.e. '1,a,c,4').\n") - - return inputs - -def ask_input_start(inputs: object) -> str: - if not hasattr(inputs, "__len__"): - raise TypeError("inputs must be an array") - - input_pos = "" - - while True: - question = "\nEnter the starting row for the inputs (i.e. '2'): " - - if inputs[0].isnumeric(): - question = "\nEnter the starting column for the inputs (i.e. 'C'): " - - choice = input(question).strip() - - if inputs[0].isnumeric() and choice.isalpha(): - input_pos = choice.upper() - break - - if inputs[0].isalpha() and choice.isnumeric(): - input_pos = choice - break - - if inputs[0].isnumeric(): - print("\nPlease enter a valid column (must be a character)\n") - else: - print("\nPlease enter a valid row (must be a number)\n") - - return input_pos - -def ask_output_row() -> str: - result_row = "" - - while True: - result_row = input("\nEnter the starting row for the results (i.e. '2'): ").strip() - - if result_row != "" and result_row.isnumeric(): - result_row = result_row - break - - print("\nPlease enter a valid row.\n") - - return result_row - -def ask_output_column() -> str: - result_col = "" - - while True: - result_col = input("\nEnter the starting column for the results (i.e. 
'B'): ").strip().upper() - - if result_col != "" and result_col.isalpha(): - break - - print("\nPlease enter a valid column.\n") - - return result_col - -def ask_skip() -> bool: - while True: - choice = input("\nShould the existing results be skipped? Results will be overwritten if skipping is disabled. (y/n) ").strip().lower() - - if choice == "y": - return True - elif choice == "n": - return False - - print("\nPlease enter a valid choice.\n") - -def ask_output_placement() -> str: - choice = "" - - while True: - choice = input("\nShould the result be output to the next row, or column? (type 'r' for next row, 'c' for next column) ").strip().lower() - - if choice == "r" or choice == "c": break - - print("\nPlease enter a valid choice ('r' or 'c')\n") - - return choice.lower() - -def ask_until(input_pos: str) -> str: - if not input_pos.isalpha() and not input_pos.isnumeric(): - raise Exception("input_pos must be either numeric or alphabetic. not both") - - question = "\nEnter the ending row for the inputs (inclusive, i.e. '2', press ENTER to continue until empty input): " - until = "" - - if input_pos.isalpha(): - question = "\nEnter the ending column for the inputs (inclusive, i.e. 
'C', press ENTER to continue until empty input): " - - while True: - choice = input(question).strip().upper() - - if choice == "": - return "" - - if input_pos.isnumeric() and choice.isnumeric() and get_numeric_value(input_pos) >= int(choice): - print("\nEnding row cannot be equal or smaller than the starting row\n") - continue - - if input_pos.isalpha() and choice.isalpha() and get_numeric_value(input_pos) >= int(choice): - print("\nEnding column cannot be equal or before the starting column\n") - continue - - if input_pos.isalpha() and choice.isalpha(): - until = choice - break - - if input_pos.isnumeric() and choice.isnumeric(): - until = choice - break - - if input_pos.isalpha(): - print("\nPlease enter a valid column (must be a character)\n") - else: - print("\nPlease enter a valid row (must be a number)\n") - - return until - -def ask_limit() -> int: - limit = 0 - - while True: - limit = input("\nEnter a limit for number of items to be processed by ChatGPT. Cached or existing results do not count towards the limit. 
(type 0 to remove limit): ").strip() - - if limit.isnumeric() and int(limit) >= 0: - limit = int(limit) - break - - print("\nPlease enter a valid limit.\n") - - return limit diff --git a/requirements.txt b/requirements.txt index dab712f..7ca2723 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,13 @@ aiohttp==3.8.5 aiosignal==1.3.1 +altgraph==0.17.3 +annotated-types==0.5.0 anyio==3.7.1 async-timeout==4.0.3 attrs==23.1.0 +auto-py-to-exe==2.37.0 +bottle==0.12.25 +bottle-websocket==0.2.9 certifi==2023.7.22 charset-normalizer==3.2.0 colorama==0.4.6 @@ -10,10 +15,19 @@ copier==8.2.0 decorator==5.1.1 docutils==0.20.1 dunamai==1.18.0 +Eel==0.16.0 et-xmlfile==1.1.0 exceptiongroup==1.1.3 +ezpyi==0.2.1 +flet==0.10.0 +flet-core==0.10.0 +flet-runtime==0.10.0 frozenlist==1.4.0 funcy==2.0 +future==0.18.3 +gevent==23.7.0 +gevent-websocket==0.10.1 +greenlet==2.0.2 h11==0.14.0 httpcore==0.17.3 httpx==0.24.1 @@ -32,8 +46,12 @@ pathspec==0.11.2 Pillow==10.0.0 plumbum==1.8.2 prompt-toolkit==3.0.36 -pydantic==1.10.12 +pydantic==2.3.0 +pydantic-core==2.6.3 Pygments==2.16.1 +pyinstaller==5.13.2 +pyinstaller-hooks-contrib==2023.8 +pyparsing==3.1.1 pypng==0.20220715.0 PyYAML==6.0.1 pyyaml-include==1.3.1 @@ -50,4 +68,7 @@ watchdog==3.0.0 wcwidth==0.2.6 websocket-client==1.6.2 websockets==11.0.3 +whichcraft==0.6.1 yarl==1.9.2 +zope.event==5.0 +zope.interface==6.0 diff --git a/save.py b/save.py new file mode 100644 index 0000000..e9a92b8 --- /dev/null +++ b/save.py @@ -0,0 +1,32 @@ +#!/usr/bin/python + +__all__ = ["save_progress", "save_and_close"] + +def save_progress(workbook, path: str): + if workbook == None: + raise TypeError("workbook must not be NoneType") + + print("\nSaving file...") + + workbook.save(path) + + print("\nResults saved to: " + path) + +def save_and_close(workbook, result_book, result_path: str): + try: + if workbook != None: + workbook.close() + except Exception as e: + print("\nFailed to close workbook\n") + print(e) + + + if result_book != None: + 
print("\nSaving file...") + + result_book.save(result_path) + result_book.close() + + print("\nResults saved to: " + result_path) + else: + print("\n\nNot saving...") diff --git a/spreadsheet.py b/spreadsheet.py new file mode 100644 index 0000000..1a5f96b --- /dev/null +++ b/spreadsheet.py @@ -0,0 +1,39 @@ +#!/usr/bin/python + +from openpyxl import load_workbook +from os.path import isfile, exists +from os import remove +from os.path import splitext +from shutil import copyfile + +__all__ = ["load_input_workbook", "create_result_book"] + +def load_input_workbook(file_path: str): + wb = None + + if not exists(file_path) or not isfile(file_path): + raise Exception("File does not exist or invalid\n") + + try: + wb = load_workbook(file_path, read_only=True) + except Exception as e: + print(e) + raise Exception("Failed to load workbook. Is the file format correct (xlsx/xlsm/xltx/xltm)?\n") + + return wb + +def create_output_book(file_path: str, result_path: str, overwrite=False): + result_book = None + + try: + copyfile(file_path, result_path) + result_book = load_workbook(result_path) + except Exception as e: + print(e) + raise Exception("\nFailed to create output worksheet.\n") + + return result_book + +def generate_result_path(file_path: str) -> str: + pre, ext = splitext(file_path) + return pre + "_output" + ext