Skip to content

Commit

Permalink
Add unit test for other modules
Browse files Browse the repository at this point in the history
  • Loading branch information
StdioA committed Aug 25, 2024
1 parent 614713f commit b02a5bc
Show file tree
Hide file tree
Showing 12 changed files with 232 additions and 15 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
testdata/example.bean binary
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
run: |
pip install pytest pytest-cov coverage
coverage run -m pytest
coverage report
coverage report --include="**/*.py" --omit="**/*_test.py"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ After input, the bot will complete the transaction details and output them for u
- [x] Clone transaction
- [ ] Withdraw transaction
- [x] Docker support
- [ ] Unit tests
- [x] Unit tests
- [ ] Web-based Chat UI
- [x] RAG (More precise element replacement through LLM, such as automatically changing "lunch" to "dinner", or automatically updating account changes, etc.)
- [ ] Support incremental construction of vector databases (With OpenAI's `text-embedding-3-large`, building a database of 1000 transactions currently costs approximately $0.003, and most embedding providers do not charge for the embedding function, so this is not a high priority)
Expand Down
3 changes: 2 additions & 1 deletion bean_utils/bean.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def find_account_by_payee(self, payee):
# Find the posting with missing units
# If not found, return the first expense account
for posting in target.postings:
if posting.units is MISSING:
if posting.units is MISSING or posting.meta.get("__automatic__"):
return posting.account
elif posting.account.startswith("Expenses:") and expense_account is None:
expense_account = posting.account
Expand Down Expand Up @@ -252,3 +252,4 @@ def parse_args(line):
def init_bean_manager(fname=None):
    """Create a ``BeanManager`` and publish it as the module-level manager.

    Args:
        fname: Passed straight through to ``BeanManager``; defaults to None.

    Returns:
        The newly created manager, which is also stored in the module
        global ``bean_manager``.
    """
    global bean_manager
    manager = BeanManager(fname)
    bean_manager = manager
    return manager
29 changes: 26 additions & 3 deletions bean_utils/bean_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from conf import _load_config_from_dict
from conf.config_data import Config
from beancount.parser import parser
from bean_utils import bean, txs_query, rag
from bean_utils import bean, txs_query


today = str(datetime.now().astimezone().date())
Expand Down Expand Up @@ -60,6 +60,10 @@ def test_account_search(mock_config):
assert exp_account == "Income:US:ETrade:PnL"

# Find account by payee
# Select by missing unit
exp_account = manager.find_account_by_payee("Chichipotle")
assert exp_account == "Expenses:Food:Restaurant"
# Select by account type
exp_account = manager.find_account_by_payee("China Garden")
assert exp_account == "Expenses:Food:Restaurant"

Expand Down Expand Up @@ -144,7 +148,7 @@ def mock_embedding(texts):
from vec_db.json_vec_db_test import easy_embedding
return [{
"embedding": easy_embedding(text)
} for text in texts], 0
} for text in texts], len(texts)


def test_generate_trx_with_vector_db(mock_config, monkeypatch):
Expand Down Expand Up @@ -232,4 +236,23 @@ def test_clone_trx(mock_config):
Expenses:Food:Restaurant
"""
assert_txs_equal(trx, exp_trx)



def test_parse_args():
    """Check ``bean.parse_args`` tokenisation and its quote-error handling."""
    # (input line, expected tokens) pairs; the last two use curly quotes.
    accepted = [
        ("", []),
        (" ", []),
        ("a b c", ["a", "b", "c"]),
        ("a 'b c' d", ["a", "b c", "d"]),
        ("a 'b\"' c", ["a", "b\"", "c"]),
        ("a 'b' c d", ["a", "b", "c", "d"]),
        ("a ”b“ c d", ["a", "b", "c", "d"]),
        ("a “b ” c d", ["a", "b ", "c", "d"]),
    ]
    for line, expected_tokens in accepted:
        assert bean.parse_args(line) == expected_tokens

    # Lines whose quote is never closed (or closed by a different kind of
    # quote) must raise ValueError.
    for broken_line in ("a 'b", "a 'b c", "a “b c'"):
        with pytest.raises(ValueError):
            bean.parse_args(broken_line)
2 changes: 1 addition & 1 deletion bots/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def fetch_bill(start: date, end: date, root_level: int = 2) -> Table:
def clone_txs(message: str) -> Union[BaseMessage, ErrorMessage]:
try:
cloned_txs = bean_manager.clone_trx(message)
except (ValueError, requests.exceptions.RequestException) as e:
except ValueError as e:
if e == NoTransactionError:
err_msg = e.args[0]
else:
Expand Down
136 changes: 136 additions & 0 deletions bots/controller_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import pytest
from datetime import datetime, date
from bean_utils import txs_query
from conf import _load_config_from_dict
from bean_utils.bean import init_bean_manager
from bots import controller
from bean_utils.bean_test import assert_txs_equal, mock_embedding
from conf.config_data import Config
from conf.i18n import gettext as _


today = str(datetime.now().astimezone().date())


@pytest.fixture
def mock_env(tmp_path, monkeypatch):
    """Load a test configuration and give the controller a fresh bean manager.

    Returns:
        The config object produced by ``_load_config_from_dict``.
    """
    embedding_section = {
        "enable": False,
        "db_store_folder": tmp_path,
    }
    beancount_section = {
        "filename": "testdata/example.bean",
        "currency": "USD",
        "account_distinguation_range": [2, 3],
    }
    # Load the configuration first, then build the manager
    # (presumably the manager reads the loaded config -- confirm).
    loaded_config = _load_config_from_dict({
        "embedding": embedding_section,
        "beancount": beancount_section,
    })
    fresh_manager = init_bean_manager()
    monkeypatch.setattr(controller, "bean_manager", fresh_manager)
    return loaded_config


def test_fetch_expense(monkeypatch, mock_env):
    """Check ``controller.fetch_expense`` for a short and a month-long window."""
    # Window 2023-06-29 .. 2023-06-30 with the default root level.
    day_table = controller.fetch_expense(date(2023, 6, 29), date(2023, 6, 30))
    assert day_table.title == _("Expenditures on 2023-06-29")
    assert day_table.headers == [_("Account"), _("Position")]
    assert day_table.rows == [["Expenses:Food", "31.59 USD"]]

    # Window 2023-06-01 .. 2023-07-01, aggregated with root_level=1.
    month_table = controller.fetch_expense(
        date(2023, 6, 1), date(2023, 7, 1), root_level=1,
    )
    assert month_table.title == _("Expenditures between 2023-06-01 - 2023-07-01")
    assert month_table.headers == [_("Account"), _("Position")]
    assert month_table.rows == [["Expenses", "7207.08 USD, 2400.00 IRAUSD"]]


def test_fetch_bill(monkeypatch, mock_env):
    """Check ``controller.fetch_bill`` for a short and a month-long window."""
    # Window 2023-06-29 .. 2023-06-30 with the default root level.
    day_table = controller.fetch_bill(date(2023, 6, 29), date(2023, 6, 30))
    assert day_table.title == _("Account changes on 2023-06-29")
    assert day_table.headers == [_("Account"), _("Position")]
    expected_day_rows = [
        ["Expenses:Food", "31.59 USD"],
        ["Liabilities:US", "-31.59 USD"],
    ]
    assert day_table.rows == expected_day_rows

    # Window 2023-06-01 .. 2023-07-01, aggregated with root_level=1.
    month_table = controller.fetch_bill(
        date(2023, 6, 1), date(2023, 7, 1), root_level=1,
    )
    assert month_table.title == _("Account changes between 2023-06-01 - 2023-07-01")
    assert month_table.headers == [_("Account"), _("Position")]
    expected_month_rows = [
        ['Assets', '3210.66768 USD, 10 VACHR, -2400.00 IRAUSD'],
        ['Expenses', '7207.08 USD, 2400.00 IRAUSD'],
        ['Income', '-10532.95 USD, -10 VACHR'],
        ['Liabilities', '115.20 USD'],
    ]
    assert month_table.rows == expected_month_rows


def test_clone_txs(monkeypatch, mock_env):
# Checks controller.clone_txs for a valid transaction text and for empty input.
# NOTE(review): indentation (including inside the triple-quoted strings below)
# appears to have been lost in this listing -- confirm against the repository.
# Normal generation: the expected transaction is the input with its date
# replaced by `today`.
param = """
2023-05-23 * "Kin Soy" "Eating" #tag1 #tag2
Assets:US:BofA:Checking -23.40 USD
Expenses:Food:Restaurant
"""
exp_trx = f"""
{today} * "Kin Soy" "Eating" #tag1 #tag2
Assets:US:BofA:Checking -23.40 USD
Expenses:Food:Restaurant
"""
response = controller.clone_txs(param)
assert isinstance(response, controller.BaseMessage)
assert_txs_equal(response.content, exp_trx)

# Generate with error: an empty message must come back as an ErrorMessage.
response = controller.clone_txs('')
assert isinstance(response, controller.ErrorMessage)
assert response.content == "No transaction found"


def test_render_txs(monkeypatch, mock_env):
# Checks controller.render_txs for a valid input line and for an unknown account.
# NOTE(review): indentation (including inside the triple-quoted f-string below)
# appears to have been lost in this listing -- confirm against the repository.
# Normal generation: the result is a sequence holding exactly one message.
responses = controller.render_txs('23.4 BofA:Checking "Kin Soy" Eating #tag1 #tag2')
assert len(responses) == 1
exp_trx = f"""
{today} * "Kin Soy" "Eating" #tag1 #tag2
Assets:US:BofA:Checking -23.40 USD
Expenses:Food:Restaurant
"""
assert isinstance(responses[0], controller.BaseMessage)
assert_txs_equal(responses[0].content, exp_trx)

# Generate with error: here the return value itself (not a list element) is
# checked to be an ErrorMessage.
response = controller.render_txs('10.00 ICBC:Checking NotFound McDonalds "Big Mac"')
assert isinstance(response, controller.ErrorMessage)
assert response.content == 'ValueError: Account ICBC:Checking not found'


def test_build_db(monkeypatch, mock_env):
    """Check ``controller.build_db`` with embedding disabled, then enabled."""
    # The fixture's configuration has embedding switched off.
    disabled_reply = controller.build_db()
    assert isinstance(disabled_reply, controller.BaseMessage)
    assert disabled_reply.content == _("Embedding is not enabled.")

    # Switch embedding on and replace the embedding call with the mock.
    enabled_section = Config.from_dict({
        "enable": True,
        "transaction_amount": 100,
        "candidates": 3,
        "output_amount": 2,
    })
    monkeypatch.setattr(mock_env, "embedding", enabled_section)
    monkeypatch.setattr(txs_query, "embedding", mock_embedding)

    enabled_reply = controller.build_db()
    assert isinstance(enabled_reply, controller.BaseMessage)
    expected_content = f"Token usage: {mock_env.embedding.transaction_amount}"
    assert enabled_reply.content == expected_content
1 change: 0 additions & 1 deletion conf/config_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

_ImmutableError = TypeError("This dictionary is immutable")


class ImmutableDict(dict):
def __setitem__(self, key, value):
raise _ImmutableError
Expand Down
45 changes: 45 additions & 0 deletions conf/config_data_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest
import yaml
from conf.config_data import Config, ImmutableDict


@pytest.mark.parametrize(
    "method, args, kwargs",
    [
        ("__setitem__", ["key", "value"], {}),
        ("__delitem__", ["key"], {}),
        ("update", [["key", "value"]], {}),
        ("clear", [], {}),
        ("pop", ["key"], {}),
        ("popitem", [], {}),
        ("setdefault", ["key", "value"], {}),
    ],
)
def test_immutable_dict(method, args, kwargs):
    """Each listed mutating method of ``ImmutableDict`` must raise ``TypeError``."""
    frozen = ImmutableDict({"key": "value"})
    mutator = getattr(frozen, method)

    with pytest.raises(TypeError):
        mutator(*args, **kwargs)


def test_config(tmp_path):
    """Write a YAML config to a temp file and check ``Config`` reads it back."""
    embedding_section = {
        "enable": False,
        "db_store_folder": str(tmp_path),
    }
    beancount_section = {
        "filename": "testdata/example.bean",
        "currency": "USD",
        "account_distinguation_range": [2, 3],
    }
    yaml_text = yaml.dump({
        "embedding": embedding_section,
        "beancount": beancount_section,
    })
    config_path = tmp_path / "config.yaml"
    config_path.write_text(yaml_text)

    config = Config(str(config_path))

    assert config
    assert config.embedding.enable is False
    assert config.beancount.filename == "testdata/example.bean"
    # Looking up a key that is not in the file yields None.
    assert config.beancount.get("whatever") is None
2 changes: 1 addition & 1 deletion testdata/example.bean
Original file line number Diff line number Diff line change
Expand Up @@ -2786,7 +2786,7 @@ option "operating_currency" "USD"

2024-06-12 * "Chichipotle" "Eating out with Julie"
Liabilities:US:Chase:Slate -37.53 USD
Expenses:Food:Restaurant 37.53 USD
Expenses:Food:Restaurant ; 37.53 USD ; Manually modified to test empty value

2024-06-14 * "Kin Soy" "Eating out with Julie"
Liabilities:US:Chase:Slate -19.95 USD
Expand Down
12 changes: 6 additions & 6 deletions vec_db/match.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math


def match(query, sentence):
    """Return a smoothed word-overlap score between *query* and *sentence*.

    Both strings are split on whitespace and surrounding double quotes are
    stripped from every word. The score is::

        (distinct sentence words found in query + 1) / (query word count + 1)

    Args:
        query: Whitespace-separated query text.
        sentence: Whitespace-separated text to compare against the query.

    Returns:
        A positive float; two empty strings score 1.0.
    """
    query_words = [word.strip('"') for word in query.split()]
    sentence_words = {word.strip('"') for word in sentence.split()}
    # Intersect two sets instead of scanning the query list once per sentence
    # word; the count of shared distinct words is the same.
    overlap = len(sentence_words & set(query_words))
    # The denominator deliberately uses the full query word count, so
    # repeated query words still lower the score.
    return (overlap + 1) / (len(query_words) + 1)
# Currently not used
# def match(query, sentence):
# sentence_seg = {x.strip('"') for x in sentence.split()}
# query_seg = [x.strip('"') for x in query.split()]
# score = sum(1 for x in sentence_seg if x in query_seg)
# return (score + 1) / (len(query_seg) + 1)


def calculate_score(txs, sentence):
Expand Down
12 changes: 12 additions & 0 deletions vec_db/sqlite_vec_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ def test_sqlite_db(tmp_path, mock_config, monkeypatch):
"embedding": easy_embedding("another-3"),
},
]
# Query without db built
candidates = sqlite_vec_db.query_by_embedding(
easy_embedding("content-1"), "sentence-1", 2,
)
assert len(candidates) == 0
# Query with empty table
sqlite_vec_db.build_db([])
candidates = sqlite_vec_db.query_by_embedding(
easy_embedding("content-1"), "sentence-1", 2,
)
assert len(candidates) == 0
# Build DB
sqlite_vec_db.build_db(txs)
db_path = sqlite_vec_db._get_db_name()
assert db_path.exists()
Expand Down

0 comments on commit b02a5bc

Please sign in to comment.