Skip to content

Commit

Permalink
extract out string representation to per class basis, code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ynx0 committed Mar 17, 2022
1 parent 5698a6f commit 9c6d8dd
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 34 deletions.
54 changes: 26 additions & 28 deletions cas/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Union
from abc import ABC
from enum import Enum, unique

Expand All @@ -9,25 +8,13 @@
# root object. either an equation or a node
class Obj(ABC):

def __str__(self):
n = self
# assert type(n) == Equation or issubclass(type(n), Node), f'Invalid input type {type(n)}'
if type(n) == Num:
# remove decimal if float is int
val = int(n.val) if n.val.is_integer() else n.val
return str(val)
elif type(n) == Var:
return n.name
elif type(n) == Function:
return f"{n.f.value}({n.x})"
elif type(n) == UnaryOp:
return f"{n.op.value}{n.a}"
elif type(n) == BinaryOp:
return f"{n.a} {n.op.value} {n.b}"
elif type(n) == Equation or type(n) == Assignment:
return f"{n.lhs} = {n.rhs}"
else:
assert False, f'Unhandled node type {type(n)}'
@property
def is_node(self):
return issubclass(type(self), Node)

@property
def is_equation(self):
return issubclass(type(self), Equation)


class Node(Obj, ABC):
Expand All @@ -42,6 +29,9 @@ def __init__(self, lhs: Node, rhs: Node):
def __repr__(self):
return f"Equation(lhs={self.lhs} rhs={self.rhs})"

def __str__(self):
return f"{self.lhs} = {self.rhs}"


class Assignment(Equation):
def __init__(self, lhs: 'Var', rhs: Node):
Expand Down Expand Up @@ -69,6 +59,9 @@ def __init__(self, op, a, b):
def __repr__(self):
return f"BinOp(op={self.op.name}, a={self.a} b={self.b})"

def __str__(self):
return f"{self.a} {self.op.value} {self.b}"


class UnaryOp(Node):
class Op(Enum):
Expand All @@ -81,6 +74,9 @@ def __init__(self, op, a):
def __repr__(self):
return f"UnaryOp(op={self.op.name}, a={self.a})"

def __str__(self):
return f"{self.op.value}{self.a}"


class Function(Node):
@unique
Expand All @@ -99,6 +95,9 @@ def __init__(self, f, x):
def __repr__(self):
return f"Fn({self.f.name}, x={self.x})"

def __str__(self):
return f"{self.f.value}({self.x})"


class Var(Node):
def __init__(self, name: str):
Expand All @@ -108,6 +107,9 @@ def __init__(self, name: str):
def __repr__(self):
return f"Var({self.name})"

def __str__(self):
return self.name


class Num(Node):

Expand All @@ -117,12 +119,8 @@ def __init__(self, val):
def __repr__(self):
return f"Num({self.val})"

def __str__(self):
# remove decimal if float is int
val = int(self.val) if self.val.is_integer() else self.val
return str(val)

# misc functions

def is_node(o: Union[Equation, Node]):
return issubclass(type(o), Node)


def is_equation(o: Union[Equation, Node]):
return type(o) == Equation
10 changes: 4 additions & 6 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,10 @@
# Node Evaluation
class UnboundVariableError(RuntimeError):
pass
# def __init__(self, msg):
# super().__init__(msg)
# self.msg = msg


def evaluate(n: Node, bound_vars: Dict[str, Node] = None):
assert is_node(n), f'Parameter {n}: {type(n)} is not a node'
assert n.is_node, f'Parameter {n}: {type(n)} is not a node'
bound_vars = bound_vars or dict()

if type(n) == Num:
Expand Down Expand Up @@ -41,7 +38,8 @@ def evaluate(n: Node, bound_vars: Dict[str, Node] = None):
if n.op == UnaryOp.Op.NEG:
return -1 * x
else:
assert False, 'unreachable'
assert False, f'Unhandled unary operation {n.op}'

elif type(n) == BinaryOp:
a = evaluate(n.a, bound_vars)
b = evaluate(n.b, bound_vars)
Expand All @@ -57,7 +55,7 @@ def evaluate(n: Node, bound_vars: Dict[str, Node] = None):
elif n.op == BinaryOp.Op.EXP:
return a ** b
else:
assert False, f'unreachable'
assert False, f'Unhandled binary operation {n.op}'

else:
print(f'evaluate: unimplemented node type {type(n)}')

0 comments on commit 9c6d8dd

Please sign in to comment.