Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bulk replace #44

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions https_everywhere/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import asyncio
import sys
import typing
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from os import cpu_count
from pathlib import Path

from binaryornot.check import is_binary
from plumbum import cli

from .core import CombinedReplacerFactory, ReplaceContext
from .core.InBufferReplacer import InBufferReplacer
from .core.InFileReplacer import InFileReplacer
from .replacers.HEReplacer import HEReplacer
from .replacers.HSTSPreloadReplacer import HSTSPreloadReplacer


class OurInBufferReplacer(InBufferReplacer):
    """In-buffer URI replacer wired to both replacement sources:
    Chrome/Mozilla HSTS preload lists and HTTPS Everywhere rulesets.

    The class-level FACS factory (see InBufferReplacer) builds the combined
    per-URI replacer from the keyword arguments accepted by __init__.
    """

    __slots__ = ()
    # Maps __init__ keyword names to the replacer class each one configures.
    FACS = CombinedReplacerFactory(
        {
            "preloads": HSTSPreloadReplacer,
            "heRulesets": HEReplacer,
        }
    )

    def __init__(self, preloads=None, heRulesets=None):
        # Both default to None: each replacer presumably loads its own
        # bundled data when not given an explicit source — TODO confirm
        # against CombinedReplacerFactory.
        super().__init__(preloads=preloads, heRulesets=heRulesets)


class OurInFileReplacer(InFileReplacer):
    """In-file URI replacer backed by an OurInBufferReplacer instance."""

    def __init__(self, preloads=None, heRulesets=None):
        # Build the buffer-level replacer first, then hand it to the parent.
        bufferReplacer = OurInBufferReplacer(preloads=preloads, heRulesets=heRulesets)
        super().__init__(bufferReplacer)


# Root plumbum application. The docstring below is runtime behavior: plumbum
# shows it as the program's --help text, so it must not be reworded casually.
# Subcommands (e.g. "bulk") attach themselves via the @CLI.subcommand decorator.
class CLI(cli.Application):
    """HTTPSEverywhere-like URI rewriter"""


class FileClassifier:
    """Decides whether a path may be processed.

    Calling an instance returns "" (falsy) for allowed paths — regular files
    that pass the filters, and directories (which the enumerator recurses
    into) — or a short human-readable reason string for paths to skip.

    FIX: the original implicitly returned None for directories despite the
    ``-> str`` annotation; directories now get an explicit "".
    """

    __slots__ = ("noSkipDot", "noSkipBinary")

    def __init__(self, noSkipDot: bool, noSkipBinary: bool) -> None:
        self.noSkipDot = noSkipDot  # when True, dot-files/dirs are not skipped
        self.noSkipBinary = noSkipBinary  # when True, binary files are not skipped

    def __call__(self, p: Path) -> str:
        # Reject anything whose path contains a dot-component (e.g. ".git"),
        # unless the user disabled dot-skipping. Checked before any
        # filesystem access, so it also works for nonexistent paths.
        if not self.noSkipDot and any(pa[0] == "." for pa in p.parts):
            return "dotfile"

        if p.is_dir():
            # Directories are allowed: FilesEnumerator recurses into them.
            return ""

        if not p.is_file():
            # Sockets, FIFOs, broken symlinks, missing paths, ...
            return "not regular file"

        # is_binary sniffs the file contents; skipped entirely when the user
        # passed --no-skip-binary (which also makes `binaryornot` optional).
        if self.noSkipBinary or not is_binary(p):
            return ""
        return "binary"


class FilesEnumerator:
    """Recursively yields the files under a path that the classifier allows.

    Disallowed paths are not yielded; instead they are reported through the
    callback together with the classifier's reason string.
    """

    __slots__ = ("classifier", "disallowedReportingCallback")

    def __init__(self, classifier, disallowedReportingCallback):
        self.classifier = classifier  # callable(Path) -> falsy (allowed) or reason str
        self.disallowedReportingCallback = disallowedReportingCallback  # callable(Path, reason)

    def __call__(self, fileOrDir: Path):
        reason = self.classifier(fileOrDir)
        if reason:
            # Skipped: tell the caller why and stop descending here.
            self.disallowedReportingCallback(fileOrDir, reason)
            return
        if not fileOrDir.is_dir():
            yield fileOrDir
            return
        # Allowed directory: recurse into each child.
        for child in fileOrDir.iterdir():
            yield from self(child)


@CLI.subcommand("bulk")
class FileRewriteCLI(cli.Application):
    """Rewrites URIs in files. Use - to consume list of files from stdin. Don't use `find`, it is a piece of shit which is impossible to configure to skip .git dirs."""

    __slots__ = ("_repl",)

    noSkipBinary = cli.Flag(
        ["--no-skip-binary", "-n"],
        help="Don't skip binary files. Allows usage without `binaryornot`",
        default=False,
    )
    noSkipDot = cli.Flag(
        ["--no-skip-dotfiles", "-d"],
        help="Don't skip files and dirs which name stem begins from dot.",
        default=False,
    )
    trace = cli.Flag(
        ["--trace", "-t"],
        help="Print info about processing of regular files",
        default=False,
    )
    noReportSkipped = cli.Flag(
        ["--no-report-skipped", "-s"],
        help="Don't report about skipped files",
        default=False,
    )

    @property
    def repl(self) -> InFileReplacer:
        """Lazily build the replacer: loading preloads/rulesets is expensive,
        so it happens on first use, not at CLI construction time."""
        if self._repl is None:
            self._repl = OurInFileReplacer()
            print(
                len(self._repl.inBufferReplacer.singleURIReplacer.children[0].preloads),
                "HSTS preloads",
            )
            print(len(self._repl.inBufferReplacer.singleURIReplacer.children[1].rulesets), "HE rules")
        return self._repl

    def processEachFileName(self, ctx: ReplaceContext, l: bytes) -> None:
        """Process one newline-terminated path read from stdin.

        *l* is raw bytes from the async stream reader (the original
        annotation said ``str`` and the return annotation said ``Path``;
        both were wrong — nothing is returned).
        """
        l = l.strip()
        if l:
            p = Path(l.decode("utf-8")).resolve().absolute()
            self.processEachFilePath(ctx, p)

    def processEachFilePath(self, ctx: ReplaceContext, p: Path) -> None:
        """Enumerate the allowed files under *p* and rewrite each one."""
        for pp in self.fe(p):
            if self.trace:
                print("Processing", pp, file=sys.stderr)
            self.repl(ctx, pp)
            if self.trace:
                print("Processed", pp, file=sys.stderr)

    async def asyncMainPathsFromStdIn(self) -> None:
        """Read file names from stdin line by line and process each.

        FIX: converted from the (removed-in-3.11) @asyncio.coroutine style
        to async def, dropped the unused ``conc`` local, and — crucially —
        the original passed only the line to processEachFileName, omitting
        the required ReplaceContext argument (a guaranteed TypeError).
        """
        asyncStdin = asyncio.StreamReader(loop=self.loop)
        await self.loop.connect_read_pipe(
            lambda: asyncio.StreamReaderProtocol(asyncStdin, loop=self.loop), sys.stdin
        )
        ctx = ReplaceContext(None)
        with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
            while not asyncStdin.at_eof():
                l = await asyncStdin.readline()
                await self.loop.run_in_executor(pool, partial(self.processEachFileName, ctx, l))

    async def asyncMainPathsFromCLI(self, filesOrDirs: typing.Iterable[typing.Union[Path, str]]) -> None:
        """Process paths given on the command line, with a progress bar when
        tqdm is installed.

        FIXES: the original ImportError fallback returned the bare iterable,
        which broke the ``with tqdm(...) as pb:`` protocol (no __enter__);
        nullcontext preserves it. Trace output goes to stderr — a tqdm
        instance is not a file object, so ``print(..., file=pb)`` was broken.
        The thread pool is also created once instead of per input path.
        """
        try:
            from tqdm import tqdm
        except ImportError:
            from contextlib import nullcontext as tqdm  # no progress bar

        ctx = ReplaceContext(None)
        replaceInEachFileWithContext = partial(self.repl, ctx)

        with tqdm(filesOrDirs) as pb, ThreadPoolExecutor(max_workers=cpu_count()) as pool:
            for fileOrDir in pb:
                fileOrDir = Path(fileOrDir).resolve().absolute()
                for f in self.fe(fileOrDir):
                    if self.trace:
                        print("Processing", f, file=sys.stderr)
                    await self.loop.run_in_executor(pool, partial(replaceInEachFileWithContext, f))
                    if self.trace:
                        print("Processed", f, file=sys.stderr)

    def disallowedReportingCallback(self, fileOrDir: Path, reasonOfDisallowal: str) -> None:
        """Report a skipped path unless --no-report-skipped was given."""
        if not self.noReportSkipped:
            print("Skipping ", fileOrDir, ":", reasonOfDisallowal)

    def main(self, *filesOrDirs):
        self._repl = None  # type: typing.Optional[OurInFileReplacer]
        self.loop = asyncio.get_event_loop()

        self.fc = FileClassifier(self.noSkipDot, self.noSkipBinary)
        self.fe = FilesEnumerator(self.fc, self.disallowedReportingCallback)

        # FIX: the help text advertises "-" for stdin mode but the original
        # only accepted "0"; accept both for backward compatibility.
        if len(filesOrDirs) == 1 and filesOrDirs[0] in ("-", "0"):
            t = self.loop.create_task(self.asyncMainPathsFromStdIn())
        else:
            t = self.loop.create_task(self.asyncMainPathsFromCLI(filesOrDirs))
        self.loop.run_until_complete(t)


if __name__ == "__main__":
    # plumbum parses argv and dispatches to the selected subcommand ("bulk").
    CLI.run()
25 changes: 19 additions & 6 deletions https_everywhere/_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,39 +821,52 @@ def _get_ruleset(hostname, rulesets=None):

logger.debug("no ruleset matches {}".format(hostname))

from icecream import ic  # NOTE(review): debug-only helper; `ic` appears unused in this diff — likely leftover, remove before merge

def _remove_trailing_slash(url):
if url[-1] == "/":
url = url[:-1]
return url

def https_url_rewrite(url, rulesets=None):
orig_url = url
if isinstance(url, str):
# In HTTPSEverywhere, URLs must contain a '/'.
if url.replace("http://", "").find("/") == -1:
url += "/"
remove_trailing_slash_if_needed = _remove_trailing_slash
parsed_url = urlparse(url)
else:
remove_trailing_slash_if_needed = lambda x: x

parsed_url = url
if hasattr(parsed_url, "geturl"):
url = parsed_url.geturl()
else:
url = str(parsed_url)

if parsed_url.scheme is None or parsed_url.host is None:
return orig_url

try:
ruleset = _get_ruleset(parsed_url.host, rulesets)
except AttributeError:
ruleset = _get_ruleset(parsed_url.netloc, rulesets)

if not ruleset:
return url
return orig_url

if not isinstance(ruleset, _Ruleset):
ruleset = _Ruleset(ruleset[0], ruleset[1])

if ruleset.exclude_url(url):
return url
return orig_url

# process rules
for rule in ruleset.rules:
logger.debug("checking rule {} -> {}".format(rule[0], rule[1]))
try:
new_url = rule[0].sub(rule[1], url)
count, new_url = rule[0].subn(rule[1], url)
except Exception as e: # pragma: no cover
logger.warning(
"failed during rule {} -> {} , input {}: {}".format(
Expand All @@ -863,7 +876,7 @@ def https_url_rewrite(url, rulesets=None):
raise

# stop if this rule was a hit
if new_url != url:
return new_url
if count:
return remove_trailing_slash_if_needed(new_url)

return url
return orig_url
7 changes: 2 additions & 5 deletions https_everywhere/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from logging_helper import setup_logging

import urllib3
from urllib3.util.url import parse_url

import requests
from requests.adapters import HTTPAdapter
Expand All @@ -13,6 +12,7 @@
from ._chrome_preload_hsts import _preload_including_subdomains
from ._mozilla_preload_hsts import _preload_remove_negative
from ._util import _check_in
from .replacers.HSTSPreloadReplacer import apply_HSTS_preload

PY2 = str != "".__class__
if PY2:
Expand Down Expand Up @@ -155,10 +155,7 @@ def __init__(self, *args, **kwargs):

def get_redirect(self, url):
if url.startswith("http://"):
p = parse_url(url)
if _check_in(self._domains, p.host):
new_url = "https:" + url[5:]
return new_url
return apply_HSTS_preload(url, self._domains)

return super(PreloadHSTSAdapter, self).get_redirect(url)

Expand Down
43 changes: 43 additions & 0 deletions https_everywhere/core/InBufferReplacer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import re
import typing

from urllib3.util.url import parse_url

from . import ReplaceContext, SingleURIReplacer

# Loose URI matcher: an http/ftp scheme, optional "user[:password]@" userinfo,
# then a host[:port] run and an optional path. NOTE(review): "\\/\\/?" makes
# the second slash optional, and the host/path character classes include "("
# but not ")" — presumably to tolerate URIs wrapped in parentheses; confirm
# these choices are intentional before tightening the pattern.
uri_re_source = "(?:http|ftp):\\/\\/?((?:[\\w-]+)(?::[\\w-]+)?@)?[\\w\\.:(-]+(?:\\/[\\w\\.:(/-]*)?"
uri_re_text = re.compile(uri_re_source)
# Same pattern compiled for bytes buffers; \w etc. are ASCII-only in bytes
# patterns, so every match is safely UTF-8-decodable.
uri_re_binary = re.compile(uri_re_source.encode("ascii"))


class InBufferReplacer(SingleURIReplacer):
    """Rewrites every URI found inside a text or binary buffer.

    Subclasses set FACS to a factory that builds the per-URI replacer from
    the keyword arguments given to __init__.
    """

    __slots__ = ("singleURIReplacer",)
    FACS = None  # set by subclasses

    def __init__(self, **kwargs):
        self.singleURIReplacer = self.__class__.FACS(**kwargs)

    def _rePlaceFuncCore(self, uri):
        # Run the single-URI replacer and hand back its context (res, count).
        ctx = ReplaceContext(uri)
        self.singleURIReplacer(ctx)
        return ctx

    def _rePlaceFuncText(self, m):
        matched = m.group(0)
        ctx = self._rePlaceFuncCore(matched)
        return ctx.res if ctx.count > 0 else matched

    def _rePlaceFuncBinary(self, m):
        matched = m.group(0)
        # Matches of the bytes pattern are ASCII-only, so decoding is safe.
        ctx = self._rePlaceFuncCore(matched.decode("utf-8"))
        return ctx.res.encode("utf-8") if ctx.count > 0 else matched

    def __call__(self, inputStr: typing.Union[str, bytes]) -> ReplaceContext:
        if isinstance(inputStr, str):
            pattern, substitute = uri_re_text, self._rePlaceFuncText
        else:
            pattern, substitute = uri_re_binary, self._rePlaceFuncBinary
        # subn yields (new_buffer, substitution_count) — exactly the pair
        # ReplaceContext is constructed from.
        return ReplaceContext(*pattern.subn(substitute, inputStr))
Loading