Skip to content

Commit

Permalink
Implemented bulk replacing in files.
Browse files Browse the repository at this point in the history
  • Loading branch information
KOLANICH committed Apr 7, 2021
1 parent b25084a commit b617978
Show file tree
Hide file tree
Showing 9 changed files with 333 additions and 5 deletions.
134 changes: 134 additions & 0 deletions https_everywhere/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import asyncio
import sys
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from os import cpu_count
from pathlib import Path

from binaryornot.check import is_binary
from plumbum import cli

from .core import CombinedReplacerFactory
from .core.InBufferReplacer import InBufferReplacer
from .core.InFileReplacer import InFileReplacer
from .replacers.HEReplacer import HEReplacer
from .replacers.HSTSPreloadReplacer import HSTSPreloadReplacer


class OurInBufferReplacer(InBufferReplacer):
    """In-buffer replacer wired to both the HSTS-preload replacer and the
    HTTPS Everywhere rulesets replacer."""

    __slots__ = ()

    # Maps constructor keyword names to the replacer classes they configure.
    FACS = CombinedReplacerFactory({"preloads": HSTSPreloadReplacer, "heRulesets": HEReplacer})

    def __init__(self, preloads=None, heRulesets=None):
        # None for either argument means "use that replacer's default dataset".
        super().__init__(preloads=preloads, heRulesets=heRulesets)


class OurInFileReplacer(InFileReplacer):
    """In-file replacer backed by :class:`OurInBufferReplacer`."""

    def __init__(self, preloads=None, heRulesets=None):
        bufferReplacer = OurInBufferReplacer(preloads=preloads, heRulesets=heRulesets)
        super().__init__(bufferReplacer)


# Root plumbum application; subcommands (e.g. `bulk` below) are registered
# onto it via the `CLI.subcommand` decorator.
class CLI(cli.Application):
    """HTTPSEverywhere-like URI rewriter"""


@CLI.subcommand("bulk")
class FileRewriteCLI(cli.Application):
    """Rewrites URIs in files. Use `-` (or `0`) to consume the list of files from stdin.
    Prefer feeding a pre-filtered file list: a plain `find` will descend into
    VCS directories such as `.git` unless explicitly configured to prune them."""

    # NOTE(review): plumbum base classes have instance __dict__s, so this does
    # not actually restrict attributes (`self.loop` is also assigned in main).
    __slots__ = ("_repl",)

    noSkipBinary = cli.Flag(
        ["--no-skip-binary", "-n"],
        help="Don't skip binary files. Allows usage without `binaryornot`",
        default=False,
    )
    noSkipDot = cli.Flag(
        ["--no-skip-dotfiles", "-d"],
        help="Don't skip files and dirs which name stem begins from dot.",
        default=False,
    )

    @property
    def repl(self):
        """Lazily-built replacer: the preload/ruleset datasets are only
        loaded on first use, and dataset sizes are reported once."""
        if self._repl is None:
            self._repl = OurInFileReplacer()
            print(
                len(self._repl.inBufferReplacer.singleURIReplacer.children[0].preloads),
                "HSTS preloads",
            )
            print(
                len(self._repl.inBufferReplacer.singleURIReplacer.children[1].rulesets), "HE rules"
            )
        return self._repl

    def processEachFileName(self, l):
        """Handles one raw line of bytes read from stdin: strip, decode,
        then process the named path."""
        l = l.strip()
        if l:
            l = l.decode("utf-8")
            return self.processEachFilePath(Path(l).resolve().absolute())

    def processEachFilePath(self, p):
        """Rewrites URIs in the file `p`, honouring the dotfile and binary skips.

        Directories are ignored (the CLI walker expands them beforehand)."""
        for pa in p.parts:
            # skip anything under a dot-dir (e.g. .git) unless overridden
            if not self.noSkipDot and pa[0] == ".":
                print("Skipping ", p, ": dotfile")
                return

        if not p.is_dir():
            if self.noSkipBinary or not is_binary(p):
                self.repl(p)
            else:
                print("Skipping ", p, ": binary")

    # BUG FIX: the original used the generator-based `@asyncio.coroutine`
    # decorator, which was removed in Python 3.11; converted to async/await.
    async def asyncMainPathsFromStdIn(self):
        """Reads file names from stdin asynchronously, one per line."""
        asyncStdin = asyncio.StreamReader(loop=self.loop)
        await self.loop.connect_read_pipe(
            lambda: asyncio.StreamReaderProtocol(asyncStdin, loop=self.loop), sys.stdin
        )
        with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
            while not asyncStdin.at_eof():
                l = await asyncStdin.readline()
                # NOTE(review): awaiting each file serializes the pool; a
                # gather() would parallelize, but the replacer's thread-safety
                # is unverified, so the original sequential behaviour is kept.
                await self.loop.run_in_executor(pool, partial(self.processEachFileName, l))

    async def asyncMainPathsFromCLI(self, filesOrDirs):
        """Processes files/dirs given on the command line; dirs are walked recursively."""
        try:
            from tqdm import tqdm
        except ImportError:
            # tqdm is optional; degrade to a no-op progress wrapper
            def tqdm(x):
                return x

        for fileOrDir in tqdm(filesOrDirs):
            fileOrDir = Path(fileOrDir).resolve().absolute()
            if fileOrDir.is_dir():
                files = [el for el in fileOrDir.glob("**/*") if not el.is_dir()]
                print(files)
            else:
                files = [fileOrDir]

            if files:
                with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
                    for f in files:
                        await self.loop.run_in_executor(pool, partial(self.processEachFilePath, f))

    def main(self, *filesOrDirs):
        self._repl = None
        self.loop = asyncio.get_event_loop()

        # BUG FIX: the help text promised `-` for stdin input but only "0"
        # was accepted; accept both for backward compatibility.
        if len(filesOrDirs) == 1 and filesOrDirs[0] in ("-", "0"):
            t = self.loop.create_task(self.asyncMainPathsFromStdIn())
        else:
            t = self.loop.create_task(self.asyncMainPathsFromCLI(filesOrDirs))
        self.loop.run_until_complete(t)


if __name__ == "__main__":
CLI.run()
7 changes: 2 additions & 5 deletions https_everywhere/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from logging_helper import setup_logging

import urllib3
from urllib3.util.url import parse_url

import requests
from requests.adapters import HTTPAdapter
Expand All @@ -13,6 +12,7 @@
from ._chrome_preload_hsts import _preload_including_subdomains
from ._mozilla_preload_hsts import _preload_remove_negative
from ._util import _check_in
from .replacers.HSTSPreloadReplacer import apply_HSTS_preload

PY2 = str != "".__class__
if PY2:
Expand Down Expand Up @@ -155,10 +155,7 @@ def __init__(self, *args, **kwargs):

def get_redirect(self, url):
if url.startswith("http://"):
p = parse_url(url)
if _check_in(self._domains, p.host):
new_url = "https:" + url[5:]
return new_url
return apply_HSTS_preload(url, self._domains)

return super(PreloadHSTSAdapter, self).get_redirect(url)

Expand Down
28 changes: 28 additions & 0 deletions https_everywhere/core/InBufferReplacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import re

from urllib3.util.url import parse_url

from . import ReplaceContext, SingleURIReplacer

# Matches http:// and ftp:// URIs (optionally with user[:password]@ credentials)
# embedded in free-form text.  Not a raw string: the doubled backslashes are
# source-level escapes for the regex backslashes.
# NOTE(review): the pattern reads "\/\/?" — the second slash is optional, so
# malformed strings like "http:/host" also match; confirm this is intentional.
uri_re = re.compile(
    "(?:http|ftp):\\/\\/?((?:[\\w-]+)(?::[\\w-]+)?@)?[\\w\\.:()-]+(?:\\/[\\w\\.:()/-]*)?"
)


class InBufferReplacer(SingleURIReplacer):
    """Applies the configured URI replacer chain to every URI found in a
    text buffer, returning a fresh :class:`ReplaceContext` with the rewritten
    text and the total replacement count."""

    __slots__ = ("singleURIReplacer",)

    # Subclasses set FACS to a CombinedReplacerFactory that builds the chain.
    FACS = None

    def __init__(self, **kwargs):
        self.singleURIReplacer = self.__class__.FACS(**kwargs)

    def _rePlaceFunc(self, matchObj):
        """re.subn callback: rewrites one matched URI, or leaves it untouched
        when no replacer fired."""
        originalURI = matchObj.group(0)
        ctx = ReplaceContext(originalURI)
        self.singleURIReplacer(ctx)
        return ctx.res if ctx.count > 0 else originalURI

    def __call__(self, inputStr):
        rewrittenText, totalCount = uri_re.subn(self._rePlaceFunc, inputStr)
        return ReplaceContext(rewrittenText, totalCount)
68 changes: 68 additions & 0 deletions https_everywhere/core/InFileReplacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from os import close
from pathlib import Path
from tempfile import NamedTemporaryFile


class InFileReplacer:
    """Applies an in-buffer replacer to a file on disk, line by line.

    Rewritten content is staged into a temporary file in the same directory
    and renamed over the original only on success, so a failure mid-way never
    corrupts the input file.
    """

    __slots__ = ("inBufferReplacer", "encoding")

    def __init__(self, inBufferReplacer, encoding="utf-8"):
        # inBufferReplacer: callable str -> context with `.res` (rewritten
        # text) and `.count` (number of replacements) attributes.
        self.inBufferReplacer = inBufferReplacer
        self.encoding = encoding

    def __call__(self, inputFilePath, safe=True):
        """Rewrite URIs in `inputFilePath` (a pathlib.Path).

        Returns the number of replacements made."""
        if safe:
            return self.safe(inputFilePath)
        return self.unsafe(inputFilePath)

    def safe(self, inputFilePath):
        """Copy-on-write rewrite.

        The staging file is only created when the first replacement is found,
        so unchanged files are never rewritten.  Returns the total count."""
        replaced = 0
        fo = None  # staging file; stays None while nothing was replaced

        try:
            with open(inputFilePath, "rt", encoding=self.encoding) as fi:
                while True:
                    l = fi.readline()
                    if not l:
                        break

                    ctx = self.inBufferReplacer(l)
                    if ctx.count:
                        if not fo:
                            # First hit: create the staging file and back-fill
                            # everything read so far (minus the current line).
                            fo = NamedTemporaryFile(
                                mode="at",
                                encoding=self.encoding,
                                suffix="new",
                                prefix=inputFilePath.stem,
                                dir=inputFilePath.parent,
                                delete=False,
                            ).__enter__()
                            # NOTE(review): text-mode tell() returns an opaque
                            # cookie and len(l) counts characters, so this
                            # arithmetic is only guaranteed for single-byte
                            # content — confirm for multi-byte UTF-8 inputs.
                            pBk = fi.tell()
                            fi.seek(0)
                            beginning = fi.read(pBk - len(l))
                            fo.write(beginning)
                            fi.seek(pBk)
                        fo.write(ctx.res)
                        replaced += ctx.count
                    else:
                        if fo:
                            fo.write(l)

        except BaseException as ex:
            # Failure: discard the partial staging file, leave input intact.
            # BUG FIX: the cleanup must be guarded by `if fo:` — `fo.name`
            # would raise AttributeError when no staging file was created.
            if fo:
                fo.__exit__(type(ex), ex, None)
                tmpFilePath = Path(fo.name)
                if tmpFilePath.exists():
                    tmpFilePath.unlink()
            raise  # bare raise preserves the original traceback
        else:
            if fo:
                fo.__exit__(None, None, None)
                # Atomic-ish commit: replace the original with the staging file.
                Path(fo.name).rename(inputFilePath)
            return replaced

    def unsafe(self, inputFilePath):
        """Placeholder for true in-place editing; currently delegates to safe()."""
        from warnings import warn

        # BUG FIX: typo in the user-facing warning ("implamented").
        warn("Unsafe in-place editing is not yet implemented")
        return self.safe(inputFilePath)
48 changes: 48 additions & 0 deletions https_everywhere/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from functools import partial


class ReplaceContext:
    """Mutable state threaded through a chain of URI replacers."""

    __slots__ = ("res", "shouldStop", "count")

    def __init__(self, res, count=0, shouldStop=False):
        # current (possibly already rewritten) text
        self.res = res
        # number of replacements applied so far
        self.count = count
        # set by a replacer to short-circuit the rest of the chain
        self.shouldStop = shouldStop


class SingleURIReplacer:
    """Abstract base class: a callable that rewrites the URI held in a
    ReplaceContext.  Subclasses must override both methods."""

    def __init__(self, arg):
        raise NotImplementedError

    def __call__(self, ctx):
        raise NotImplementedError


class CombinedReplacer(SingleURIReplacer):
    """Runs a sequence of replacers over the same context, in order."""

    __slots__ = ("children",)

    def __init__(self, children):
        self.children = children

    def __call__(self, ctx):
        for child in self.children:
            child(ctx)
            if ctx.shouldStop:
                # a replacer requested short-circuiting the chain
                break
        return ctx


class CombinedReplacerFactory:
    """Builds a CombinedReplacer from keyword arguments: each known keyword
    name is mapped to its replacer constructor; unknown names are ignored."""

    __slots__ = ("args2Ctors", "ctor")

    def __init__(self, args2Ctors):
        # mapping: keyword-argument name -> replacer constructor
        self.args2Ctors = args2Ctors

    def _gen_replacers(self, kwargs):
        for name, value in kwargs.items():
            ctor = self.args2Ctors.get(name)
            if ctor:
                yield ctor(value)

    def __call__(self, **kwargs):
        return CombinedReplacer(tuple(self._gen_replacers(kwargs)))
19 changes: 19 additions & 0 deletions https_everywhere/replacers/HEReplacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .. import _rules
from .._rules import _get_rulesets, https_url_rewrite
from ..core import SingleURIReplacer


class HEReplacer(SingleURIReplacer):
    """Rewrites URIs using the HTTPS Everywhere rulesets."""

    __slots__ = ("rulesets",)

    def __init__(self, rulesets):
        # A None argument means "use the default rulesets from _rules".
        if rulesets is None:
            _get_rulesets()  # populates _rules._DATA
            rulesets = _rules._DATA
        self.rulesets = rulesets

    def __call__(self, ctx):
        rewritten = https_url_rewrite(ctx.res, self.rulesets)
        if rewritten != ctx.res:
            ctx.count += 1
        ctx.res = rewritten
31 changes: 31 additions & 0 deletions https_everywhere/replacers/HSTSPreloadReplacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from urllib3.util.url import parse_url

from .._chrome_preload_hsts import \
_preload_including_subdomains as _get_preload_chrome
from .._mozilla_preload_hsts import \
_preload_remove_negative as _get_preload_mozilla
from .._util import _check_in
from ..core import SingleURIReplacer


def apply_HSTS_preload(url, domains):
    """Return *url* with its scheme swapped to https when its host is found
    in the HSTS preload set *domains*; otherwise return *url* unchanged."""
    parsed = parse_url(url)
    if not _check_in(domains, parsed.host):
        return url
    # keep everything after "<scheme>:" intact, only replace the scheme
    return "https:" + url[len(parsed.scheme) + 1:]


class HSTSPreloadReplacer(SingleURIReplacer):
    """Rewrites URIs whose host appears in the HSTS preload lists."""

    __slots__ = ("preloads",)

    def __init__(self, preloads):
        # Default: union of the Mozilla and Chrome preload lists.
        if preloads is None:
            preloads = _get_preload_mozilla() | _get_preload_chrome()
        self.preloads = preloads

    def __call__(self, ctx):
        rewritten = apply_HSTS_preload(ctx.res, self.preloads)
        if rewritten != ctx.res:
            ctx.count += 1
        ctx.res = rewritten
Empty file.
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,7 @@
classifiers=classifiers.splitlines(),
tests_require=["unittest-expander", "lxml", "tldextract", "regex"],
# lxml is optional, needed for testing upstream rules
entry_points = {
"console_scripts": ["pyhttpeverywhere = https_everywhere.__main__:CLI"]
}
)

0 comments on commit b617978

Please sign in to comment.