Skip to content

Commit

Permalink
Merge pull request #178 from Vardan2009/master
Browse files Browse the repository at this point in the history
Added basic security prompts
  • Loading branch information
Almas-Ali authored Oct 27, 2024
2 parents 34da0c0 + e10f5ea commit ab23b3d
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 5 deletions.
15 changes: 15 additions & 0 deletions core/builtin_classes/file_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from core.errors import RTError
from core.parser import Context, RTResult
from core.tokens import Position
from core import security


class FileObject(BuiltInObject):
Expand All @@ -14,6 +15,8 @@ class FileObject(BuiltInObject):
@operator("__constructor__")
@check([String, String], [None, String("r")])
def constructor(self, path: String, mode: String) -> RTResult[Value]:
security.security_prompt("disk_access")

allowed_modes = [None, "r", "w", "a", "r+", "w+", "a+"] # Allowed modes for opening files
res = RTResult[Value]()
if mode.value not in allowed_modes:
Expand All @@ -29,6 +32,8 @@ def constructor(self, path: String, mode: String) -> RTResult[Value]:
@args(["count"], [Number(-1)])
@method
def read(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("disk_access")

res = RTResult[Value]()
count = ctx.symbol_table.get("count")
assert count is not None
Expand All @@ -49,6 +54,8 @@ def read(self, ctx: Context) -> RTResult[Value]:
@args([])
@method
def readline(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("disk_access")

res = RTResult[Value]()
try:
value = self.file.readline()
Expand All @@ -60,6 +67,8 @@ def readline(self, ctx: Context) -> RTResult[Value]:
@args([])
@method
def readlines(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("disk_access")

res = RTResult[Value]()
try:
value = self.file.readlines()
Expand All @@ -71,6 +80,8 @@ def readlines(self, ctx: Context) -> RTResult[Value]:
@args(["data"])
@method
def write(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("disk_access")

res = RTResult[Value]()
data = ctx.symbol_table.get("data")
assert data is not None
Expand All @@ -88,12 +99,16 @@ def write(self, ctx: Context) -> RTResult[Value]:
@args([])
@method
def close(self, _ctx: Context) -> RTResult[Value]:
security.security_prompt("disk_access")

res = RTResult[Value]()
self.file.close()
return res.success(Null.null())

@args([])
@method
def is_closed(self, _ctx: Context) -> RTResult[Value]:
security.security_prompt("disk_access")

res = RTResult[Value]()
return res.success(Boolean(self.file.closed))
11 changes: 11 additions & 0 deletions core/builtin_classes/requests_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from core.datatypes import HashMap, Null, String, Value, deradonify, radonify
from core.errors import RTError
from core.parser import Context, RTResult
from core import security


class RequestsObject(BuiltInObject):
Expand All @@ -18,6 +19,8 @@ def constructor(self) -> RTResult[Value]:
@args(["url", "headers"], [None, HashMap({})])
@method
def get(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("network_access")

res = RTResult[Value]()
url = ctx.symbol_table.get("url")
assert url is not None
Expand All @@ -38,6 +41,8 @@ def get(self, ctx: Context) -> RTResult[Value]:
@args(["url", "data", "headers"], [None, HashMap({}), HashMap({})])
@method
def post(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("network_access")

res = RTResult[Value]()
url = ctx.symbol_table.get("url")
assert url is not None
Expand Down Expand Up @@ -66,6 +71,8 @@ def post(self, ctx: Context) -> RTResult[Value]:
@args(["url", "data", "headers"], [None, HashMap({}), HashMap({})])
@method
def put(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("network_access")

res = RTResult[Value]()
url = ctx.symbol_table.get("url")
assert url is not None
Expand Down Expand Up @@ -93,6 +100,8 @@ def put(self, ctx: Context) -> RTResult[Value]:
@args(["url", "headers"], [None, HashMap({})])
@method
def delete(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("network_access")

res = RTResult[Value]()
url = ctx.symbol_table.get("url")
assert url is not None
Expand All @@ -113,6 +122,8 @@ def delete(self, ctx: Context) -> RTResult[Value]:
@args(["url", "data", "headers"], [None, HashMap({}), HashMap({})])
@method
def patch(self, ctx: Context) -> RTResult[Value]:
security.security_prompt("network_access")

res = RTResult[Value]()
url = ctx.symbol_table.get("url")
assert url is not None
Expand Down
3 changes: 3 additions & 0 deletions core/builtin_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from sys import stdout
from typing import Callable, Generic, NoReturn, Optional, ParamSpec, Protocol, Sequence, Union, cast
from core import security

from core.datatypes import (
Array,
Expand Down Expand Up @@ -427,6 +428,8 @@ def execute_type(self, exec_ctx: Context) -> RTResult[Value]:

@args(["code", "ns"])
def execute_pyapi(self, exec_ctx: Context) -> RTResult[Value]:
security.security_prompt("pyapi_access")

res = RTResult[Value]()

code = exec_ctx.symbol_table.get("code")
Expand Down
38 changes: 38 additions & 0 deletions core/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Literal
from core.colortools import Log

# Define all types of security prompts
SecurityPromptType = Literal["pyapi_access", "disk_access", "network_access"]
type_messages: dict[str, str] = {
"pyapi_access": "This program is attempting to use the Python API",
"disk_access": "This program is attempting to access the disk",
"network_access": "This program is attempting to access the network",
}

# List of allowed actions (used during code execution)
allowed: dict[str, bool] = {}


# !!! Only used for tests !!!
def allow_all_permissions() -> None:
allowed["pyapi_access"] = True
allowed["disk_access"] = True
allowed["network_access"] = True


def security_prompt(type: SecurityPromptType) -> None:
# If action already allowed, continue
if type in allowed:
return
# Log the message and get a y/n prompt by user
print(f"{Log.deep_warning(f"[{type.upper()}]")} {Log.deep_info(type_messages[type], True)}. Continue execution?")
print(f"{Log.deep_purple("[Y/n] -> ")}", end="")
# If user agreed
if input().lower() == "y":
# Add action to allowed list
allowed[type] = True
return
# Exit program
print("Permission denied by user.")
exit(1)
return
15 changes: 11 additions & 4 deletions radon.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,7 @@ def main(argv: list[str]) -> None:
usage(program_name, sys.stderr)
print(f"ERROR: {arg} requires an argument", file=sys.stderr)
exit(1)
source_file = argv[0]
break # allow program to use remaining args
source_file = argv.pop(0)
case "--version" | "-v":
print(base_core.__version__)
exit(0)
Expand All @@ -131,8 +130,16 @@ def main(argv: list[str]) -> None:
usage(program_name, sys.stderr)
print(f"ERROR: {arg} requires an argument", file=sys.stderr)
exit(1)
command = argv[0]
break # allow program to use remaining args
command = argv.pop(0)
# These flags starting with --allow should only be used for testing, and not be allowed to be set by a user
case "--allow-all" | "-A":
base_core.security.allow_all_permissions()
case "--allow-disk" | "-D":
base_core.security.allowed["disk_access"] = True
case "--allow-py" | "-P":
base_core.security.allowed["pyapi_access"] = True
case "--allow-network" | "-W":
base_core.security.allowed["network_access"] = True
case _:
usage(program_name, sys.stderr)
print(f"ERROR: Unknown argument '{arg}'", file=sys.stderr)
Expand Down
4 changes: 3 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ def dump(self, path: str) -> None:


def run_test(test: str) -> Output:
proc = subprocess.run([sys.executable, "radon.py", "-s", test], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc = subprocess.run(
[sys.executable, "radon.py", "-s", test, "-A"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
return Output(
proc.returncode,
proc.stdout.decode("utf-8").replace("\r\n", "\n"),
Expand Down

0 comments on commit ab23b3d

Please sign in to comment.