From b02a5bcb99c005df41ad507fed05ed6678ec357c Mon Sep 17 00:00:00 2001 From: David Dai Date: Sun, 25 Aug 2024 22:31:30 +0800 Subject: [PATCH] Add unit test for other modules --- .gitattributes | 1 + .github/workflows/unit_test.yml | 2 +- README.md | 2 +- bean_utils/bean.py | 3 +- bean_utils/bean_test.py | 29 ++++++- bots/controller.py | 2 +- bots/controller_test.py | 136 ++++++++++++++++++++++++++++++++ conf/config_data.py | 1 - conf/config_data_test.py | 45 +++++++++++ testdata/example.bean | 2 +- vec_db/match.py | 12 +-- vec_db/sqlite_vec_db_test.py | 12 +++ 12 files changed, 232 insertions(+), 15 deletions(-) create mode 100644 .gitattributes create mode 100644 bots/controller_test.py create mode 100644 conf/config_data_test.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..b5be640 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +testdata/example.bean binary diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index 0730515..f7723eb 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -31,4 +31,4 @@ jobs: run: | pip install pytest pytest-cov coverage coverage run -m pytest - coverage report + coverage report --include="**/*.py" --omit="**/*_test.py" diff --git a/README.md b/README.md index c36d4ea..6509c23 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ After input, the bot will complete the transaction details and output them for u - [x] Clone transaction - [ ] Withdraw transaction - [x] Docker support -- [ ] Unit tests +- [x] Unit tests - [ ] Web-based Chat UI - [x] RAG (More precise element replacement through LLM, such as automatically changing "lunch" to "dinner", or automatically updating account changes, etc.) - [ ] Support incremental construction of vector databases (If using OpenAI's `text-embedding-3-large`, currently building a database consisting of 1000 transactions costs approximately $0.003, and most providers of embedding do not charge for the embedding function, so the priority is not high) diff --git a/bean_utils/bean.py b/bean_utils/bean.py index f32ff00..28394e2 100644 --- a/bean_utils/bean.py +++ b/bean_utils/bean.py @@ -90,7 +90,7 @@ def find_account_by_payee(self, payee): # Find the posting with missing units # If not found, return the first expense account for posting in target.postings: - if posting.units is MISSING: + if posting.units is MISSING or posting.meta.get("__automatic__"): return posting.account elif posting.account.startswith("Expenses:") and expense_account is None: expense_account = posting.account @@ -252,3 +252,4 @@ def parse_args(line): def init_bean_manager(fname=None): global bean_manager bean_manager = BeanManager(fname) + return bean_manager diff --git a/bean_utils/bean_test.py b/bean_utils/bean_test.py index 41fc002..6a8d4eb 100644 --- a/bean_utils/bean_test.py +++ b/bean_utils/bean_test.py @@ -4,7 +4,7 @@ from conf import _load_config_from_dict from conf.config_data import Config from beancount.parser import parser -from bean_utils import bean, txs_query, rag +from bean_utils import bean, txs_query today = str(datetime.now().astimezone().date()) @@ -60,6 +60,10 @@ def test_account_search(mock_config): assert exp_account == "Income:US:ETrade:PnL" # Find account by payee + # Select by missing unit + exp_account = manager.find_account_by_payee("Chichipotle") + assert exp_account == "Expenses:Food:Restaurant" + # Select by account type exp_account = manager.find_account_by_payee("China Garden") assert exp_account == "Expenses:Food:Restaurant" @@ -144,7 +148,7 @@ def mock_embedding(texts): from vec_db.json_vec_db_test import easy_embedding return [{ "embedding": easy_embedding(text) - } for text in texts], 0 + } for text in texts], len(texts) def test_generate_trx_with_vector_db(mock_config, monkeypatch): @@ -232,4 +236,23 @@ def test_clone_trx(mock_config): Expenses:Food:Restaurant """ assert_txs_equal(trx, exp_trx) - \ No newline at end of file + + +def test_parse_args(): + assert bean.parse_args("") == [] + assert bean.parse_args(" ") == [] + assert bean.parse_args("a b c") == ["a", "b", "c"] + assert bean.parse_args("a 'b c' d") == ["a", "b c", "d"] + assert bean.parse_args("a 'b\"' c") == ["a", "b\"", "c"] + assert bean.parse_args("a 'b' c d") == ["a", "b", "c", "d"] + assert bean.parse_args("a ”b“ c d") == ["a", "b", "c", "d"] + assert bean.parse_args("a “b ” c d") == ["a", "b ", "c", "d"] + + with pytest.raises(ValueError): + bean.parse_args("a 'b") + + with pytest.raises(ValueError): + bean.parse_args("a 'b c") + + with pytest.raises(ValueError): + bean.parse_args("a “b c'") diff --git a/bots/controller.py b/bots/controller.py index 5fa8466..d5cd426 100644 --- a/bots/controller.py +++ b/bots/controller.py @@ -81,7 +81,7 @@ def fetch_bill(start: date, end: date, root_level: int = 2) -> Table: def clone_txs(message: str) -> Union[BaseMessage, ErrorMessage]: try: cloned_txs = bean_manager.clone_trx(message) - except (ValueError, requests.exceptions.RequestException) as e: + except ValueError as e: if e == NoTransactionError: err_msg = e.args[0] else: diff --git a/bots/controller_test.py b/bots/controller_test.py new file mode 100644 index 0000000..386fcf8 --- /dev/null +++ b/bots/controller_test.py @@ -0,0 +1,136 @@ +import pytest +from datetime import datetime, date +from bean_utils import txs_query +from conf import _load_config_from_dict +from bean_utils.bean import init_bean_manager +from bots import controller +from bean_utils.bean_test import assert_txs_equal, mock_embedding +from conf.config_data import Config +from conf.i18n import gettext as _ + + +today = str(datetime.now().astimezone().date()) + + +@pytest.fixture +def mock_env(tmp_path, monkeypatch): + conf_data = { + "embedding": { + "enable": False, + "db_store_folder": tmp_path, + }, + "beancount": { + "filename": "testdata/example.bean", + "currency": "USD", + "account_distinguation_range": [2, 3], + } + } + config = _load_config_from_dict(conf_data) + manager = init_bean_manager() + monkeypatch.setattr(controller, "bean_manager", manager) + return config + + +def test_fetch_expense(monkeypatch, mock_env): + # Start and end is the same + start, end = date(2023, 6, 29), date(2023, 6, 30) + resp_table = controller.fetch_expense(start, end) + assert resp_table.title == _("Expenditures on 2023-06-29") + assert resp_table.headers == [_("Account"), _("Position")] + assert resp_table.rows == [ + ["Expenses:Food", "31.59 USD"], + ] + + # Start and end is different + # Test level + start, end = date(2023, 6, 1), date(2023, 7, 1) + resp_table = controller.fetch_expense(start, end, root_level=1) + assert resp_table.title == _("Expenditures between 2023-06-01 - 2023-07-01") + assert resp_table.headers == [_("Account"), _("Position")] + assert resp_table.rows == [ + ["Expenses", "7207.08 USD, 2400.00 IRAUSD"], + ] + + +def test_fetch_bill(monkeypatch, mock_env): + # Start and end is the same + start, end = date(2023, 6, 29), date(2023, 6, 30) + resp_table = controller.fetch_bill(start, end) + assert resp_table.title == _("Account changes on 2023-06-29") + assert resp_table.headers == [_("Account"), _("Position")] + assert resp_table.rows == [ + ["Expenses:Food", "31.59 USD"], + ["Liabilities:US", "-31.59 USD"], + ] + + # Start and end is different + # Test level + start, end = date(2023, 6, 1), date(2023, 7, 1) + resp_table = controller.fetch_bill(start, end, root_level=1) + assert resp_table.title == _("Account changes between 2023-06-01 - 2023-07-01") + assert resp_table.headers == [_("Account"), _("Position")] + assert resp_table.rows == [ + ['Assets', '3210.66768 USD, 10 VACHR, -2400.00 IRAUSD'], + ['Expenses', '7207.08 USD, 2400.00 IRAUSD'], + ['Income', '-10532.95 USD, -10 VACHR'], + ['Liabilities', '115.20 USD'] + ] + + +def test_clone_txs(monkeypatch, mock_env): + # Normal generation + param = """ + 2023-05-23 * "Kin Soy" "Eating" #tag1 #tag2 + Assets:US:BofA:Checking -23.40 USD + Expenses:Food:Restaurant + """ + exp_trx = f""" + {today} * "Kin Soy" "Eating" #tag1 #tag2 + Assets:US:BofA:Checking -23.40 USD + Expenses:Food:Restaurant + """ + response = controller.clone_txs(param) + assert isinstance(response, controller.BaseMessage) + assert_txs_equal(response.content, exp_trx) + + # Generate with error + response = controller.clone_txs('') + assert isinstance(response, controller.ErrorMessage) + assert response.content == "No transaction found" + + +def test_render_txs(monkeypatch, mock_env): + # Normal generation + responses = controller.render_txs('23.4 BofA:Checking "Kin Soy" Eating #tag1 #tag2') + assert len(responses) == 1 + exp_trx = f""" + {today} * "Kin Soy" "Eating" #tag1 #tag2 + Assets:US:BofA:Checking -23.40 USD + Expenses:Food:Restaurant + """ + assert isinstance(responses[0], controller.BaseMessage) + assert_txs_equal(responses[0].content, exp_trx) + + # Generate with error + response = controller.render_txs('10.00 ICBC:Checking NotFound McDonalds "Big Mac"') + assert isinstance(response, controller.ErrorMessage) + assert response.content == 'ValueError: Account ICBC:Checking not found' + + +def test_build_db(monkeypatch, mock_env): + # Build db without embedding enabled + response = controller.build_db() + assert isinstance(response, controller.BaseMessage) + assert response.content == _("Embedding is not enabled.") + + # Build db with embedding enabled + monkeypatch.setattr(mock_env, "embedding", Config.from_dict({ + "enable": True, + "transaction_amount": 100, + "candidates": 3, + "output_amount": 2, + })) + monkeypatch.setattr(txs_query, "embedding", mock_embedding) + response = controller.build_db() + assert isinstance(response, controller.BaseMessage) + assert response.content == f"Token usage: {mock_env.embedding.transaction_amount}" diff --git a/conf/config_data.py b/conf/config_data.py index 027cce9..03067ee 100644 --- a/conf/config_data.py +++ b/conf/config_data.py @@ -3,7 +3,6 @@ _ImmutableError = TypeError("This dictionary is immutable") - class ImmutableDict(dict): def __setitem__(self, key, value): raise _ImmutableError diff --git a/conf/config_data_test.py b/conf/config_data_test.py new file mode 100644 index 0000000..54d3d55 --- /dev/null +++ b/conf/config_data_test.py @@ -0,0 +1,45 @@ +import pytest +import yaml +from conf.config_data import Config, ImmutableDict + + +@pytest.mark.parametrize( + "method, args, kwargs", + [ + ("__setitem__", ["key", "value"], {}), + ("__delitem__", ["key"], {}), + ("update", [["key", "value"]], {}), + ("clear", [], {}), + ("pop", ["key"], {}), + ("popitem", [], {}), + ("setdefault", ["key", "value"], {}), + ], +) +def test_immutable_dict(method, args, kwargs): + immutable_dict = ImmutableDict({"key": "value"}) + + with pytest.raises(TypeError): + getattr(immutable_dict, method)(*args, **kwargs) + + +def test_config(tmp_path): + conf_data = { + "embedding": { + "enable": False, + "db_store_folder": str(tmp_path), + }, + "beancount": { + "filename": "testdata/example.bean", + "currency": "USD", + "account_distinguation_range": [2, 3], + } + } + config_path = tmp_path / "config.yaml" + with open(config_path, 'w') as file: + yaml.dump(conf_data, file) + config = Config(str(config_path)) + + assert config + assert config.embedding.enable is False + assert config.beancount.filename == "testdata/example.bean" + assert config.beancount.get("whatever") is None diff --git a/testdata/example.bean b/testdata/example.bean index 42d262f..30917c3 100644 --- a/testdata/example.bean +++ b/testdata/example.bean @@ -2786,7 +2786,7 @@ option "operating_currency" "USD" 2024-06-12 * "Chichipotle" "Eating out with Julie" Liabilities:US:Chase:Slate -37.53 USD - Expenses:Food:Restaurant 37.53 USD + Expenses:Food:Restaurant ; 37.53 USD ; Manually modified to test empty value 2024-06-14 * "Kin Soy" "Eating out with Julie" Liabilities:US:Chase:Slate -19.95 USD diff --git a/vec_db/match.py b/vec_db/match.py index 308e14f..b37d995 100644 --- a/vec_db/match.py +++ b/vec_db/match.py @@ -1,11 +1,11 @@ import math - -def match(query, sentence): - sentence_seg = {x.strip('"') for x in sentence.split()} - query_seg = [x.strip('"') for x in query.split()] - score = sum(1 for x in sentence_seg if x in query_seg) - return (score + 1) / (len(query_seg) + 1) +# Currently not used +# def match(query, sentence): +# sentence_seg = {x.strip('"') for x in sentence.split()} +# query_seg = [x.strip('"') for x in query.split()] +# score = sum(1 for x in sentence_seg if x in query_seg) +# return (score + 1) / (len(query_seg) + 1) def calculate_score(txs, sentence): diff --git a/vec_db/sqlite_vec_db_test.py b/vec_db/sqlite_vec_db_test.py index b4716f1..60a5505 100644 --- a/vec_db/sqlite_vec_db_test.py +++ b/vec_db/sqlite_vec_db_test.py @@ -35,6 +35,18 @@ def test_sqlite_db(tmp_path, mock_config, monkeypatch): "embedding": easy_embedding("another-3"), }, ] + # Query without db built + candidates = sqlite_vec_db.query_by_embedding( + easy_embedding("content-1"), "sentence-1", 2, + ) + assert len(candidates) == 0 + # Query with empty table + sqlite_vec_db.build_db([]) + candidates = sqlite_vec_db.query_by_embedding( + easy_embedding("content-1"), "sentence-1", 2, + ) + assert len(candidates) == 0 + # Build DB sqlite_vec_db.build_db(txs) db_path = sqlite_vec_db._get_db_name() assert db_path.exists()