Skip to content

Commit

Permalink
Add docstrings and simplify the structure of main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
StdioA committed Aug 27, 2024
1 parent 995251c commit 8370981
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ lint:
@ruff check

test:
coverage run --source=. --omit="**/*_test.py,bots/mmbot.py,bots/telegram_bot.py,main.py,test.py" -m pytest
coverage run --source=. --omit="**/*_test.py,bots/*_bot.py,main.py,test.py" -m pytest
coverage report
@coverage html
129 changes: 120 additions & 9 deletions bean_utils/bean.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ def __init__(self, fname=None) -> None:
self._load()

def _load(self):
"""
Load the beancount file and initialize the internal state.
This function loads the beancount file specified by `self.fname` and
initializes the internal state. The internal state includes the entries
(transactions and metadata), options, accounts, modification times of
included files, and account files.
This function does not return anything. It updates the following instance
variables:
- `_entries`: a list of parsed entries.
- `_options`: a dictionary of options.
- `_accounts`: a set of accounts.
- `mtimes`: a dictionary mapping filenames to modification times.
- `account_files`: a set of filenames.
"""
self._entries, errors, self._options = loader.load_file(self.fname)
self._accounts = set()
self.mtimes = {}
Expand All @@ -51,7 +67,13 @@ def _load(self):
self.mtimes[f] = Path(f).stat().st_mtime

def _auto_reload(self, accounts_only=False):
# Check and reload
"""
Check and reload if any of the files have been modified.
Args:
accounts_only (bool): If True, only check the files that contain account
transactions. Defaults to False.
"""
files_to_check = self.mtimes.keys()
if accounts_only:
files_to_check = self.account_files
Expand All @@ -76,13 +98,34 @@ def accounts(self):
return self._accounts

def find_account(self, account_str):
"""
Find an account that contains the given string.
Args:
account_str (str): A substring to search for in the account string.
Returns:
str or None: The account that contains the given substring, or None
if no such account is found.
"""
for account in self.accounts:
if account_str in account:
return account
return None

def find_account_by_payee(self, payee):
# Find the account with the same payee
"""
Find the account with the same payee.
Args:
payee (str): The payee to search for.
Returns:
str or None: The account with the same payee, or None if not found. If the
transaction has multiple postings with missing units, the first one is
returned. If no expense account is found, the first expense account in
the transaction is returned.
"""
target = None
for trx in reversed(self._entries):
if not isinstance(trx, Transaction):
Expand All @@ -103,9 +146,27 @@ def find_account_by_payee(self, payee):
return expense_account

def run_query(self, q):
"""
A procedural interface to the `beancount.query` module.
"""
return query.run_query(self.entries, self.options, q)

def match_new_args(self, args) -> List[List[str]]:
def modify_args_via_vec(self, args) -> List[List[str]]:
"""
Given a list of arguments, modify the arguments to match the transactions in the vector
database.
Args:
args (List[str]): The arguments to modify.
Returns:
List[List[str]]: A list of modified arguments.
This function queries the vector database to find transactions that match the given
arguments. It then rebuilds the narrations for each matching transaction and returns a
list of modified arguments. If no matching transactions are found, an empty list is
returned.
"""
# Query from vector db
matched_txs = query_txs(" ".join(args[1:]))
candidate_args = []
Expand All @@ -120,16 +181,32 @@ def match_new_args(self, args) -> List[List[str]]:
return candidate_args

def build_trx(self, args):
"""
The core function of transaction generation.
Args:
args (List[str]): A list of strings representing the transaction arguments.
The format is: {amount} {from_account} {to_account} {payee} {narration} [#{tag1} #{tag2} ...].
The to_account and narration are optional.
Returns:
str: The transaction string in the beancount format.
Raises:
ValueError: If from_account or to_account is not found.
"""
amount, from_acc, to_acc, *extra = args

amount = Decimal(amount)
from_account = self.find_account(from_acc)
to_account = self.find_account(to_acc)
payee = None

# from_account is required
if from_account is None:
err_msg = _("Account {acc} not found").format(acc=from_acc)
raise ValueError(err_msg)
# Try to find the payee if to_account is not found
if to_account is None:
payee = to_acc
to_account = self.find_account_by_payee(payee)
Expand All @@ -139,7 +216,7 @@ def build_trx(self, args):

if payee is None:
payee, *extra = extra
kwargs = {
trx_info = {
"date": datetime.now().astimezone().date(),
"payee": payee,
"from_account": from_account,
Expand All @@ -154,14 +231,30 @@ def build_trx(self, args):
for arg in extra:
if arg.startswith(("#", "^")):
tags.append(arg)
elif not kwargs["desc"]:
kwargs["desc"] = arg
elif not trx_info["desc"]:
trx_info["desc"] = arg
if tags:
kwargs["tags"] = " " + " ".join(tags)
trx_info["tags"] = " " + " ".join(tags)

return transaction_tmpl.format(**kwargs)
return transaction_tmpl.format(**trx_info)

def generate_trx(self, line) -> List[str]:
"""
The entry procedure for transaction generation.
If the line cannot be directly converted into a transaction,
the function will attempt to match it from a vector database
or a RAG model. If all attempts fail, a ValueError will be raised.
Args:
line (str): The line to generate a transaction from.
Returns:
List[str]: A list of transactions generated from the line.
Raises:
ValueError: If all attempts to generate a transaction fail.
"""
args = parse_args(line)
try:
return [self.build_trx(args)]
Expand All @@ -177,7 +270,7 @@ def generate_trx(self, line) -> List[str]:
if vec_enabled:
# Query from vector db
candidate_txs = []
for new_args in self.match_new_args(args):
for new_args in self.modify_args_via_vec(args):
with contextlib.suppress(ValueError):
candidate_txs.append(self.build_trx(new_args))
if candidate_txs:
Expand All @@ -187,6 +280,15 @@ def generate_trx(self, line) -> List[str]:
raise e

def clone_trx(self, text) -> str:
"""
Clone a transaction from text.
Args:
text (str): Text containing one transaction.
Returns:
str: A transaction with today's date.
"""
entries, _, _ = parser.parse_string(text)
try:
txs = next(e for e in entries if isinstance(e, Transaction))
Expand All @@ -202,6 +304,15 @@ def clone_trx(self, text) -> str:
return "\n".join(segments)

def commit_trx(self, data):
"""
Commit a transaction to the beancount file and format it.
Args:
data (str): The transaction data in beancount format.
Raises:
SubprocessError: If the bean-format command fails to execute.
"""
fname = self.fname
with open(fname, 'a') as f:
f.write("\n" + data + "\n")
Expand Down
53 changes: 52 additions & 1 deletion bean_utils/vec_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ def embedding(texts):


def convert_account(account):
"""
Convert an account string to a specific segment.
Args:
account (str): The account string to convert.
Returns:
str: The converted account string.
This function takes an account string and converts it to a specific segment
based on the configuration in conf.config.beancount.account_distinguation_range.
If the account string does not contain any segment, the original account string
is returned.
"""
dist_range = conf.config.beancount.account_distinguation_range
segments = account.split(":")
if isinstance(dist_range, int):
Expand All @@ -44,7 +58,22 @@ def escape_quotes(s):


def convert_to_natural_language(transaction) -> str:
# date = transaction.date.strftime('%Y-%m-%d')
"""
Convert a transaction object to a string representation of natural language for input to RAG.
Args:
transaction (Transaction): A Transaction object.
Returns:
str: The natural language representation of the transaction.
The format of the representation is:
`"{payee}" "{description}" "{account1} {account2} ..." [{#tag1} {#tag2} ...]`,
where `{payee}` is the payee of the transaction, `{description}` is the narration,
and `{account1} {account2} ...` is a space-separated list of accounts in the transaction.
The accounts are converted to the most distinguishable level as specified in the configuration.
If the transaction has tags, they are appended to the end of the sentence.
"""
payee = f'"{escape_quotes(transaction.payee)}"'
description = f'"{escape_quotes(transaction.narration)}"'
accounts = " ".join([convert_account(posting.account) for posting in transaction.postings])
Expand All @@ -56,6 +85,18 @@ def convert_to_natural_language(transaction) -> str:


def build_tx_db(transactions):
"""
Build a transaction database from the given transactions. This function
consolidates the latest transactions and calculates their embeddings.
The embeddings are stored in a database for future use.
Args:
transactions (list): A list of Transaction objects representing the
transactions.
Returns:
int: The total number of tokens used for embedding.
"""
_content_cache = {}
def _read_lines(fname, start, end):
if fname not in _content_cache:
Expand Down Expand Up @@ -103,6 +144,16 @@ def _read_lines(fname, start, end):


def query_txs(query):
"""
Query transactions based on the given query string.
Args:
query (str): The query string to search for.
Returns:
list: A list of matched transactions. The length of the list is determined
by the `output_amount` configuration.
"""
candidates = conf.config.embedding.candidates or 3
output_amount = conf.config.embedding.output_amount or 1
match = query_by_embedding(embedding([query])[0][0]["embedding"], query, candidates)
Expand Down
File renamed without changes.
16 changes: 8 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,27 @@ def init_bot(config_path):
bean.init_bean_manager()



def main():
def parse_args():
parser = argparse.ArgumentParser(prog='beanbot',
description='Bot to translate text into beancount transaction')
subparser = parser.add_subparsers(title='sub command', dest='command')
subparser = parser.add_subparsers(title='sub command', required=True, dest='command')

telegram_parser = subparser.add_parser("telegram")
telegram_parser.add_argument('-c', nargs="?", type=str, default="config.yaml", help="config file path")
mattermost_parser = subparser.add_parser("mattermost")
mattermost_parser.add_argument('-c', nargs="?", type=str, default="config.yaml", help="config file path")

args = parser.parse_args()
if args.command is None:
parser.print_help()
return
return parser.parse_args()


def main():
args = parse_args()
init_bot(args.c)

if args.command == "telegram":
from bots.telegram_bot import run_bot
elif args.command == "mattermost":
from bots.mmbot import run_bot
from bots.mattermost_bot import run_bot
run_bot()


Expand Down
2 changes: 1 addition & 1 deletion vec_db/sqlite_vec_db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from vec_db import sqlite_vec_db


def test_sqlite_db(tmp_path, mock_config, monkeypatch):
def test_sqlite_db(tmp_path, mock_config, monkeypatch):
monkeypatch.setattr(sqlite_vec_db, "_db", None)
# Build DB
txs = [
Expand Down

0 comments on commit 8370981

Please sign in to comment.