Skip to content

Commit

Permalink
Merge pull request #112 from datacamp/dev-protowhat
Browse files Browse the repository at this point in the history
Dev protowhat
  • Loading branch information
machow authored Aug 4, 2017
2 parents ee186c7 + 26fd7a4 commit 40b02ba
Show file tree
Hide file tree
Showing 19 changed files with 62 additions and 399 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
name='sqlwhat',
version=version,
packages=['sqlwhat', 'sqlwhat.checks'],
install_requires=['markdown2', 'antlr-plsql>=0.1.0', 'antlr-tsql>=0.1.0'],
install_requires=['protowhat', 'antlr-plsql>=0.1.0', 'antlr-tsql>=0.1.0'],
description = 'Submission correctness tests for sql',
author = 'Michael Chow',
author_email = '[email protected]',
Expand Down
91 changes: 0 additions & 91 deletions sqlwhat/Reporter.py

This file was deleted.

72 changes: 6 additions & 66 deletions sqlwhat/State.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,71 +2,11 @@
import inspect

from sqlwhat.selectors import Dispatcher
from protowhat.State import State as BaseState

class State:
def __init__(self,
student_code,
solution_code,
pre_exercise_code,
student_conn,
solution_conn,
student_result,
solution_result,
reporter,
solution_ast = None,
student_ast = None,
ast_dispatcher = None,
history = tuple()):
class State(BaseState):

for k,v in locals().items():
if k != 'self': setattr(self, k, v)

if ast_dispatcher is None:
# MCE doesn't always have connection - fallback on postgresql
dn = student_conn.dialect.name if student_conn else 'postgresql'
self.ast_dispatcher = Dispatcher.from_dialect(dn)

# Parse solution and student code
# solution code raises an exception if can't be parsed
if solution_ast is None: self.solution_ast = self.ast_dispatcher.parse(solution_code)
if student_ast is None: self.student_ast = self.ast_dispatcher.parse(student_code)

def get_ast_path(self):
rev_checks = filter(lambda x: x['type'] in ['check_field', 'check_node'], reversed(self.history))

try:
last = next(rev_checks)
if last['type'] == 'check_node':
# final check was for a node
return self.ast_dispatcher.describe(last['node'],
index = last['kwargs']['index'],
msg = "{index}{node_name}")
else:
node = next(rev_checks)
if node['type'] == 'check_node':
# checked for node, then for target, so can give rich description
return self.ast_dispatcher.describe(node['node'],
field = last['kwargs']['name'],
index = last['kwargs']['index'],
msg = "{index}{field_name} of the {node_name}")
except StopIteration:
return self.ast_dispatcher.describe(self.student_ast, "{node_name}")


def do_test(self, *args, highlight=None, **kwargs):
highlight = self.student_ast if highlight is None else highlight

return self.reporter.do_test(*args, highlight=highlight, **kwargs)

def to_child(self, **kwargs):
"""Basic implementation of returning a child state"""

good_pars = inspect.signature(self.__init__).parameters
bad_pars = set(kwargs) - set(good_pars)
if bad_pars:
raise KeyError("Invalid init params for State: %s"% ", ".join(bad_pars))

child = copy(self)
for k, v in kwargs.items(): setattr(child, k, v)
child.parent = self
return child
def get_dispatcher(self):
# MCE doesn't always have connection - fallback on postgresql
dialect = self.student_conn.dialect.name if self.student_conn else 'postgresql'
return Dispatcher.from_dialect(dialect)
65 changes: 0 additions & 65 deletions sqlwhat/Test.py

This file was deleted.

3 changes: 2 additions & 1 deletion sqlwhat/checks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sqlwhat.checks.check_result import check_result, test_has_columns, test_nrows, test_ncols, test_column, allow_error, test_error, test_name_miscased, test_column_name, sort_rows
from sqlwhat.checks.check_logic import fail, multi, extend, test_or, test_correct
from sqlwhat.checks.check_funcs import check_node, check_field, test_student_typed, has_equal_ast, test_mc, success_msg, verify_ast_parses

from protowhat.checks.check_logic import fail, multi, extend, test_or, test_correct
17 changes: 8 additions & 9 deletions sqlwhat/checks/check_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from sqlwhat.Test import TestFail, Test
from sqlwhat.State import State

from functools import partial, wraps
Expand Down Expand Up @@ -61,7 +60,7 @@ def check_node(state, name, index=0, missing_msg="Could not find the {index}{nod
except IndexError:
# use speaker on ast dialect module to get message, or fall back to generic
_msg = state.ast_dispatcher.describe(sol_stmt, missing_msg, index = index)
state.do_test(Test(_msg or MSG_CHECK_FALLBACK))
state.do_test(_msg or MSG_CHECK_FALLBACK)

action = {'type': 'check_node', 'kwargs': {'name': name, 'index': index}, 'node': stu_stmt}

Expand Down Expand Up @@ -108,12 +107,12 @@ def check_field(state, name, index=None, missing_msg="Could not find the {index}
except:
# use speaker on ast dialect module to get message, or fall back to generic
_msg = state.ast_dispatcher.describe(state.student_ast, missing_msg, field = name, index = index)
state.do_test(Test(_msg or MSG_CHECK_FALLBACK))
state.do_test(_msg or MSG_CHECK_FALLBACK)

# fail if attribute exists, but is none only for student
if stu_attr is None and sol_attr is not None:
_msg = state.ast_dispatcher.describe(state.student_ast, missing_msg, field = name, index = index)
state.do_test(Test(_msg))
state.do_test(_msg)

action = {'type': 'check_field', 'kwargs': {'name': name, 'index': index}}

Expand Down Expand Up @@ -176,7 +175,7 @@ def test_student_typed(state, text, msg="Submission does not contain the code `{
res = text in stu_text if fixed else re.search(text, stu_text)

if not res:
state.do_test(Test(_msg))
state.do_test(_msg)

return state

Expand Down Expand Up @@ -221,8 +220,8 @@ def has_equal_ast(state,
sol_rep = repr(sol_ast)

_msg = msg.format(ast_path = state.get_ast_path())
if exact and (sol_rep != stu_rep): state.do_test(Test(_msg or MSG_CHECK_FALLBACK))
elif not exact and (sol_rep not in stu_rep): state.do_test(Test(_msg or MSG_CHECK_FALLBACK))
if exact and (sol_rep != stu_rep): state.do_test(_msg or MSG_CHECK_FALLBACK)
elif not exact and (sol_rep not in stu_rep): state.do_test(_msg or MSG_CHECK_FALLBACK)

return state

Expand All @@ -247,7 +246,7 @@ def test_mc(state, correct, msgs):
exec(state.student_code, globals(), ctxt)
sel_indx = ctxt['selected_option']
if sel_indx != correct:
state.do_test(Test(msgs[sel_indx-1]))
state.do_test(msgs[sel_indx-1])
else:
state.reporter.success_msg = msgs[correct-1]

Expand All @@ -274,6 +273,6 @@ def success_msg(state, msg):
def verify_ast_parses(state):
asts = [state.student_ast, state.solution_ast]
if any(isinstance(c, state.ast_dispatcher.ast.AntlrException) for c in asts):
state.do_test(Test("AST did not parse"))
state.do_test("AST did not parse")

return state
4 changes: 2 additions & 2 deletions sqlwhat/checks/check_logic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from sqlwhat.Test import TestFail, Test
from protowhat.Test import TestFail
from types import GeneratorType
from functools import partial

def fail(state, msg=""):
"""Always fails the SCT, with an optional msg."""
state.do_test(Test(msg))
state.do_test(msg)

return state

Expand Down
16 changes: 7 additions & 9 deletions sqlwhat/checks/check_result.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from sqlwhat.Test import TestFail, Test

def allow_error(state):
"""Allow submission to pass, even if it originally caused a database error."""

Expand All @@ -13,7 +11,7 @@ def test_error(state, msg="Your command returned the following error: {}"):
error = state.reporter.get_error()

if error is not None:
state.do_test(Test(msg.format(error)))
state.do_test(msg.format(error))

return state

Expand All @@ -39,7 +37,7 @@ def test_has_columns(state, msg="Your result did not output any columns."):
"""Test if the student's query result contains any columns"""

if not state.student_result:
state.do_test(Test(msg))
state.do_test(msg)

return state

Expand All @@ -55,7 +53,7 @@ def test_nrows(state, msg="Result has {} row(s) but expected {}."):

if n_stu != n_sol:
_msg = msg.format(n_stu, n_sol)
state.do_test(Test(_msg))
state.do_test(_msg)

return state

Expand All @@ -70,7 +68,7 @@ def test_ncols(state, msg="Result has {} column(s) but expected {}."):

if n_stu != n_sol:
_msg = msg.format(n_stu, n_sol)
state.do_test(Test(_msg))
state.do_test(_msg)

return state

Expand All @@ -86,7 +84,7 @@ def test_name_miscased(state, name,

if name.lower() in stu_lower and name not in stu_res:
_msg = msg.format(stu_lower[name.lower()], name)
state.do_test(Test(_msg))
state.do_test(_msg)

return state

Expand All @@ -107,7 +105,7 @@ def test_column_name(state, name,

if name.lower() not in stu_lower:
_msg = msg.format(name)
state.do_test(Test(_msg))
state.do_test(_msg)

return state

Expand Down Expand Up @@ -171,7 +169,7 @@ def test_column(state, name, msg="Column `{}` in the solution does not have a co

# fail test if no match
_msg = msg.format(name)
state.do_test(Test(_msg))
state.do_test(_msg)

# return state just in case, but should never happen
return state
Expand Down
Loading

0 comments on commit 40b02ba

Please sign in to comment.