From 3108c2a13ae8a7372c4a74c0cfb90cc9ae798afd Mon Sep 17 00:00:00 2001
From: KOLANICH
Date: Tue, 6 Apr 2021 17:34:40 +0300
Subject: [PATCH] Implemented bulk replacing in files.

---
 https_everywhere/__main__.py              | 193 ++++++++++++++++++
 https_everywhere/_rules.py                |  25 ++-
 https_everywhere/adapter.py               |   7 +-
 https_everywhere/core/InBufferReplacer.py |  43 ++++
 https_everywhere/core/InFileReplacer.py   | 144 +++++++++++++
 https_everywhere/core/__init__.py         |  48 +++++
 https_everywhere/replacers/HEReplacer.py  |  19 ++
 .../replacers/HSTSPreloadReplacer.py       |  31 +++
 https_everywhere/replacers/__init__.py    |   0
 setup.py                                  |   3 +
 10 files changed, 502 insertions(+), 11 deletions(-)
 create mode 100644 https_everywhere/__main__.py
 create mode 100644 https_everywhere/core/InBufferReplacer.py
 create mode 100644 https_everywhere/core/InFileReplacer.py
 create mode 100644 https_everywhere/core/__init__.py
 create mode 100644 https_everywhere/replacers/HEReplacer.py
 create mode 100644 https_everywhere/replacers/HSTSPreloadReplacer.py
 create mode 100644 https_everywhere/replacers/__init__.py
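A sketch of how the new pieces compose, for review (all names are from this
patch; constructing the replacer with no arguments fetches the HSTS preload
lists and the HE rulesets, so the first use does network I/O, and the example
string is illustrative):

    from https_everywhere.__main__ import OurInBufferReplacer

    repl = OurInBufferReplacer()  # children: HSTSPreloadReplacer, HEReplacer
    ctx = repl("see http://eff.org/ and http://example.com/page")
    print(ctx.count, ctx.res)  # ctx.count = how many URIs were rewritten
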
diff --git a/https_everywhere/__main__.py b/https_everywhere/__main__.py
new file mode 100644
index 0000000..e2dd06e
--- /dev/null
+++ b/https_everywhere/__main__.py
@@ -0,0 +1,193 @@
+import asyncio
+import sys
+import typing
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from os import cpu_count
+from pathlib import Path
+
+try:
+    from binaryornot.check import is_binary
+except ImportError:
+    is_binary = None  # optional; required unless --no-skip-binary is used
+
+from plumbum import cli
+
+from .core import CombinedReplacerFactory, ReplaceContext
+from .core.InBufferReplacer import InBufferReplacer
+from .core.InFileReplacer import InFileReplacer
+from .replacers.HEReplacer import HEReplacer
+from .replacers.HSTSPreloadReplacer import HSTSPreloadReplacer
+
+
+class OurInBufferReplacer(InBufferReplacer):
+    __slots__ = ()
+    FACS = CombinedReplacerFactory(
+        {
+            "preloads": HSTSPreloadReplacer,
+            "heRulesets": HEReplacer,
+        }
+    )
+
+    def __init__(self, preloads=None, heRulesets=None):
+        super().__init__(preloads=preloads, heRulesets=heRulesets)
+
+
+class OurInFileReplacer(InFileReplacer):
+    def __init__(self, preloads=None, heRulesets=None):
+        super().__init__(OurInBufferReplacer(preloads=preloads, heRulesets=heRulesets))
+
+
+class CLI(cli.Application):
+    """HTTPSEverywhere-like URI rewriter"""
+
+
+class FileClassifier:
+    __slots__ = ("noSkipDot", "noSkipBinary")
+
+    def __init__(self, noSkipDot: bool, noSkipBinary: bool):
+        if not noSkipBinary and is_binary is None:
+            raise ImportError("`binaryornot` is required to skip binary files; install it or use --no-skip-binary")
+        self.noSkipDot = noSkipDot
+        self.noSkipBinary = noSkipBinary
+
+    def __call__(self, p: Path) -> str:
+        for pa in p.parts:
+            if not self.noSkipDot and pa[0] == ".":
+                return "dotfile"
+
+        if p.is_dir():
+            return ""
+        if p.is_file():
+            if self.noSkipBinary or not is_binary(p):
+                return ""
+            return "binary"
+        return "not a regular file"
+
+
+class FilesEnumerator:
+    __slots__ = ("classifier", "disallowedReportingCallback")
+
+    def __init__(self, classifier, disallowedReportingCallback):
+        self.classifier = classifier
+        self.disallowedReportingCallback = disallowedReportingCallback
+
+    def __call__(self, fileOrDir: Path):
+        reasonOfDisallowal = self.classifier(fileOrDir)
+        if not reasonOfDisallowal:
+            if fileOrDir.is_dir():
+                for f in fileOrDir.iterdir():
+                    yield from self(f)
+            else:
+                yield fileOrDir
+        else:
+            self.disallowedReportingCallback(fileOrDir, reasonOfDisallowal)
+
+
+@CLI.subcommand("bulk")
+class FileRewriteCLI(cli.Application):
+    """Rewrites URIs in files. Use - to consume a list of files from stdin.
+    Don't use `find`; it is hard to configure it to skip .git dirs."""
+
+    __slots__ = ("_repl", "loop", "fc", "fe")
+
+    @property
+    def repl(self) -> InFileReplacer:
+        if self._repl is None:
+            self._repl = OurInFileReplacer()
+            print(
+                len(self._repl.inBufferReplacer.singleURIReplacer.children[0].preloads),
+                "HSTS preloads",
+            )
+            print(
+                len(self._repl.inBufferReplacer.singleURIReplacer.children[1].rulesets),
+                "HE rules",
+            )
+        return self._repl
+
+    def processEachFileName(self, ctx: ReplaceContext, l: bytes) -> None:
+        l = l.strip()
+        if l:
+            p = Path(l.decode("utf-8")).resolve().absolute()
+            self.processEachFilePath(ctx, p)
+
+    def processEachFilePath(self, ctx: ReplaceContext, p: Path) -> None:
+        for pp in self.fe(p):
+            if self.trace:
+                print("Processing", pp, file=sys.stderr)
+            self.repl(ctx, pp)
+            if self.trace:
+                print("Processed", pp, file=sys.stderr)
+
+    async def asyncMainPathsFromStdIn(self):
+        asyncStdin = asyncio.StreamReader(loop=self.loop)
+        await self.loop.connect_read_pipe(
+            lambda: asyncio.StreamReaderProtocol(asyncStdin, loop=self.loop), sys.stdin
+        )
+        ctx = ReplaceContext(None)
+        with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
+            while not asyncStdin.at_eof():
+                l = await asyncStdin.readline()
+                await self.loop.run_in_executor(
+                    pool, partial(self.processEachFileName, ctx, l)
+                )
+
+    async def asyncMainPathsFromCLI(self, filesOrDirs: typing.Iterable[typing.Union[Path, str]]):
+        try:
+            from tqdm import tqdm
+        except ImportError:
+            from contextlib import contextmanager
+
+            @contextmanager
+            def tqdm(x):
+                yield x
+
+        ctx = ReplaceContext(None)
+        replaceInEachFileWithContext = partial(self.repl, ctx)
+
+        with tqdm(filesOrDirs) as pb:
+            for fileOrDir in pb:
+                fileOrDir = Path(fileOrDir).resolve().absolute()
+
+                files = tuple(self.fe(fileOrDir))
+
+                if files:
+                    with ThreadPoolExecutor(max_workers=cpu_count()) as pool:
+                        for f in files:
+                            if self.trace:
+                                print("Processing", f, file=sys.stderr)
+                            await self.loop.run_in_executor(
+                                pool, partial(replaceInEachFileWithContext, f)
+                            )
+                            if self.trace:
+                                print("Processed", f, file=sys.stderr)
+
+    noSkipBinary = cli.Flag(
+        ["--no-skip-binary", "-n"],
+        help="Don't skip binary files. Allows usage without `binaryornot`.",
+        default=False,
+    )
+    noSkipDot = cli.Flag(
+        ["--no-skip-dotfiles", "-d"],
+        help="Don't skip files and dirs whose names begin with a dot.",
+        default=False,
+    )
+    trace = cli.Flag(
+        ["--trace", "-t"],
+        help="Print info about processing of regular files",
+        default=False,
+    )
+    noReportSkipped = cli.Flag(
+        ["--no-report-skipped", "-s"],
+        help="Don't report skipped files",
+        default=False,
+    )
+
+    def disallowedReportingCallback(self, fileOrDir: Path, reasonOfDisallowal: str) -> None:
+        if not self.noReportSkipped:
+            print("Skipping", fileOrDir, ":", reasonOfDisallowal, file=sys.stderr)
+
+    def main(self, *filesOrDirs):
+        self._repl = None  # type: typing.Optional[OurInFileReplacer]
+        self.loop = asyncio.get_event_loop()
+
+        self.fc = FileClassifier(self.noSkipDot, self.noSkipBinary)
+        self.fe = FilesEnumerator(self.fc, self.disallowedReportingCallback)
+
+        if len(filesOrDirs) == 1 and filesOrDirs[0] == "-":
+            t = self.loop.create_task(self.asyncMainPathsFromStdIn())
+        else:
+            t = self.loop.create_task(self.asyncMainPathsFromCLI(filesOrDirs))
+        self.loop.run_until_complete(t)
+
+
+if __name__ == "__main__":
+    CLI.run()
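The traversal helpers can also be driven directly; a sketch with a
hypothetical reporting callback (FileClassifier needs `binaryornot` unless
noSkipBinary is set):

    from pathlib import Path
    from https_everywhere.__main__ import FileClassifier, FilesEnumerator

    def report(path, reason):
        print("skipped", path, "-", reason)

    fe = FilesEnumerator(FileClassifier(noSkipDot=False, noSkipBinary=False), report)
    for f in fe(Path(".").resolve()):  # yields regular, non-dot, non-binary files
        print("would process", f)
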
diff --git a/https_everywhere/_rules.py b/https_everywhere/_rules.py
index f5f33f3..312bc13 100644
--- a/https_everywhere/_rules.py
+++ b/https_everywhere/_rules.py
@@ -821,39 +821,52 @@ def _get_ruleset(hostname, rulesets=None):
         logger.debug("no ruleset matches {}".format(hostname))
 
 
+def _remove_trailing_slash(url):
+    if url[-1] == "/":
+        url = url[:-1]
+    return url
+
+
 def https_url_rewrite(url, rulesets=None):
+    orig_url = url
+    remove_trailing_slash_if_needed = lambda x: x
     if isinstance(url, str):
         # In HTTPSEverywhere, URLs must contain a '/'.
         if url.replace("http://", "").find("/") == -1:
             url += "/"
+            remove_trailing_slash_if_needed = _remove_trailing_slash
         parsed_url = urlparse(url)
     else:
         parsed_url = url
         if hasattr(parsed_url, "geturl"):
             url = parsed_url.geturl()
         else:
             url = str(parsed_url)
 
+    scheme = getattr(parsed_url, "scheme", None)
+    host = getattr(parsed_url, "host", None) or getattr(parsed_url, "netloc", None)
+    if not scheme or not host:
+        return orig_url
+
     try:
         ruleset = _get_ruleset(parsed_url.host, rulesets)
     except AttributeError:
         ruleset = _get_ruleset(parsed_url.netloc, rulesets)
 
     if not ruleset:
-        return url
+        return orig_url
 
     if not isinstance(ruleset, _Ruleset):
         ruleset = _Ruleset(ruleset[0], ruleset[1])
 
     if ruleset.exclude_url(url):
-        return url
+        return orig_url
 
     # process rules
     for rule in ruleset.rules:
         logger.debug("checking rule {} -> {}".format(rule[0], rule[1]))
         try:
-            new_url = rule[0].sub(rule[1], url)
+            new_url, count = rule[0].subn(rule[1], url)
         except Exception as e:  # pragma: no cover
             logger.warning(
                 "failed during rule {} -> {} , input {}: {}".format(
@@ -863,7 +876,7 @@ def https_url_rewrite(url, rulesets=None):
             raise
 
         # stop if this rule was a hit
-        if new_url != url:
-            return new_url
+        if count:
+            return remove_trailing_slash_if_needed(new_url)
 
-    return url
+    return orig_url
diff --git a/https_everywhere/adapter.py b/https_everywhere/adapter.py
index 4d16a5c..3783e7c 100644
--- a/https_everywhere/adapter.py
+++ b/https_everywhere/adapter.py
@@ -3,7 +3,6 @@
 from logging_helper import setup_logging
 
 import urllib3
-from urllib3.util.url import parse_url
 import requests
 from requests.adapters import HTTPAdapter
 
@@ -13,6 +12,7 @@
 from ._chrome_preload_hsts import _preload_including_subdomains
 from ._mozilla_preload_hsts import _preload_remove_negative
 from ._util import _check_in
+from .replacers.HSTSPreloadReplacer import apply_HSTS_preload
 
 PY2 = str != "".__class__
 if PY2:
@@ -155,10 +155,7 @@ def __init__(self, *args, **kwargs):
 
     def get_redirect(self, url):
         if url.startswith("http://"):
-            p = parse_url(url)
-            if _check_in(self._domains, p.host):
-                new_url = "https:" + url[5:]
-                return new_url
+            new_url = apply_HSTS_preload(url, self._domains)
+            if new_url != url:
+                return new_url
 
         return super(PreloadHSTSAdapter, self).get_redirect(url)
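The point of orig_url and _remove_trailing_slash above: when nothing matches,
the caller must get its input back byte-identical, because the bulk rewriter
splices results straight into files. A sketch (hypothetical host; ruleset data
is loaded on first call, which may hit the network):

    from https_everywhere._rules import https_url_rewrite

    url = "http://host-without-ruleset.example"
    # A "/" is appended internally for matching and stripped again on return,
    # so an unmatched URL round-trips unchanged:
    assert https_url_rewrite(url) == url
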
diff --git a/https_everywhere/core/InBufferReplacer.py b/https_everywhere/core/InBufferReplacer.py
new file mode 100644
index 0000000..7feba67
--- /dev/null
+++ b/https_everywhere/core/InBufferReplacer.py
@@ -0,0 +1,43 @@
+import re
+import typing
+
+from . import ReplaceContext, SingleURIReplacer
+
+uri_re_source = r"(?:http|ftp)://?((?:[\w-]+)(?::[\w-]+)?@)?[\w.:(-]+(?:/[\w.:(/-]*)?"
+uri_re_text = re.compile(uri_re_source)
+uri_re_binary = re.compile(uri_re_source.encode("ascii"))
+
+
+class InBufferReplacer(SingleURIReplacer):
+    __slots__ = ("singleURIReplacer",)
+    FACS = None
+
+    def __init__(self, **kwargs):
+        self.singleURIReplacer = self.__class__.FACS(**kwargs)
+
+    def _rePlaceFuncCore(self, uri):
+        ctx = ReplaceContext(uri)
+        self.singleURIReplacer(ctx)
+        return ctx
+
+    def _rePlaceFuncText(self, m):
+        uri = m.group(0)
+        ctx = self._rePlaceFuncCore(uri)
+        if ctx.count > 0:
+            return ctx.res
+        return uri
+
+    def _rePlaceFuncBinary(self, m):
+        uri = m.group(0)
+        ctx = self._rePlaceFuncCore(uri.decode("utf-8"))
+        if ctx.count > 0:
+            return ctx.res.encode("utf-8")
+        return uri
+
+    def __call__(self, inputStr: typing.Union[str, bytes]) -> ReplaceContext:
+        if isinstance(inputStr, str):
+            return ReplaceContext(*uri_re_text.subn(self._rePlaceFuncText, inputStr))
+        else:
+            return ReplaceContext(*uri_re_binary.subn(self._rePlaceFuncBinary, inputStr))
diff --git a/https_everywhere/core/InFileReplacer.py b/https_everywhere/core/InFileReplacer.py
new file mode 100644
index 0000000..78c6674
--- /dev/null
+++ b/https_everywhere/core/InFileReplacer.py
@@ -0,0 +1,144 @@
+import typing
+from pathlib import Path
+from shutil import copystat
+from tempfile import NamedTemporaryFile
+from warnings import warn
+
+from . import ReplaceContext
+from .InBufferReplacer import InBufferReplacer
+
+chardet = None  # lazily imported
+fallbackDefaultEncoding = "utf-8"
+
+
+class InFileReplacer:
+    __slots__ = ("inBufferReplacer", "encoding")
+
+    def __init__(self, inBufferReplacer: InBufferReplacer, encoding: typing.Optional[str] = None) -> None:
+        global chardet
+        self.inBufferReplacer = inBufferReplacer
+        self.encoding = encoding
+        if encoding is None:
+            try:
+                import chardet
+            except ImportError:
+                warn("`chardet` is not installed. Assuming utf-8; files in other encodings will be decoded incorrectly.")
+                self.encoding = fallbackDefaultEncoding
+
+    def __call__(self, ctx: ReplaceContext, inputFilePath: Path, safe: bool = True) -> None:
+        if safe:
+            return self.safe(ctx, inputFilePath)
+        return self.unsafe(ctx, inputFilePath)
+
+    def safe(self, ctx: ReplaceContext, inputFilePath: Path) -> None:
+        fo = None
+        tmpFilePath = None
+
+        if not self.encoding:
+            encDetector = chardet.UniversalDetector()
+            encodingPrevConfidence = -1.0
+            encoding = fallbackDefaultEncoding
+
+        try:
+            with open(inputFilePath, "rb") as fi:
+                while True:
+                    origLineStart = fi.tell()
+                    l = fi.readline()
+                    origLineEnd = fi.tell()
+
+                    if not l:
+                        break
+
+                    if not self.encoding:
+                        # Black magic: UniversalDetector doesn't produce a usable
+                        # result in some cases unless it is closed, so we close it,
+                        # then reset its `done` flag so it keeps accepting input.
+                        encDetector.feed(l)
+                        encDetector.close()
+                        encDetector.done = False
+                        res = encDetector.result
+                        detectedConfidence = res["confidence"]
+                        detectedEncoding = res["encoding"]
+                        encodings2try = [(detectedConfidence, detectedEncoding)]
+                        if detectedEncoding != encoding:
+                            encodings2try.append((encodingPrevConfidence, encoding))
+
+                        if detectedConfidence < encodingPrevConfidence:
+                            encodings2try.reverse()
+
+                        encodings2try.append((0, fallbackDefaultEncoding))
+
+                        encoding = None
+                        decodedLine = None
+                        for curConfidence, curEnc in encodings2try:
+                            if not curEnc:
+                                continue
+                            try:
+                                decodedLine = l.decode(curEnc)
+                            except ValueError:
+                                pass
+                            except LookupError:
+                                warn("Unsupported encoding: " + curEnc)
+                            else:
+                                encoding = curEnc
+                                encodingPrevConfidence = curConfidence
+                                break
+                        if decodedLine is None:
+                            warn("No supported encoding has been detected for the line " + repr(l) + "; processing it as binary.")
+                            encoding = None
+                            decodedLine = l
+                    else:
+                        encoding = self.encoding
+                        decodedLine = l.decode(encoding)
+
+                    cctx = self.inBufferReplacer(decodedLine)
+                    if cctx.count:
+                        if not fo:
+                            fo = NamedTemporaryFile(
+                                mode="ab",
+                                suffix="new",
+                                prefix=inputFilePath.stem,
+                                dir=inputFilePath.parent,
+                                delete=False,
+                            ).__enter__()
+                            tmpFilePath = Path(fo.name)
+                            # copy everything before the first changed line verbatim
+                            fi.seek(0)
+                            fo.write(fi.read(origLineStart))
+                            fi.seek(origLineEnd)
+                        if encoding:
+                            fo.write(cctx.res.encode(encoding))
+                        else:
+                            fo.write(cctx.res)
+                        ctx.count += cctx.count
+                    else:
+                        if fo:
+                            fo.write(l)
+
+        except BaseException as ex:
+            if fo:
+                fo.__exit__(type(ex), ex, None)
+                if tmpFilePath.exists():
+                    tmpFilePath.unlink()
+            raise
+        else:
+            if fo:
+                fo.__exit__(None, None, None)
+                copystat(inputFilePath, tmpFilePath)
+                tmpFilePath.rename(inputFilePath)
+
+    def unsafe(self, ctx: ReplaceContext, inputFilePath: Path) -> None:
+        warn("Unsafe in-place editing is not implemented yet; falling back to safe mode.")
+        return self.safe(ctx, inputFilePath)
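A sketch of rewriting one file with this stack (the file name is
hypothetical; content is rewritten into a NamedTemporaryFile next to the
original, which then replaces it):

    from pathlib import Path
    from https_everywhere.__main__ import OurInFileReplacer
    from https_everywhere.core import ReplaceContext

    repl = OurInFileReplacer()  # preload + HE-ruleset replacers
    ctx = ReplaceContext(None)  # accumulates the total replacement count
    repl(ctx, Path("page.html"))
    print(ctx.count, "replacements")
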
diff --git a/https_everywhere/core/__init__.py b/https_everywhere/core/__init__.py
new file mode 100644
index 0000000..bb35dd5
--- /dev/null
+++ b/https_everywhere/core/__init__.py
@@ -0,0 +1,48 @@
+class ReplaceContext:
+    __slots__ = ("res", "shouldStop", "count")
+
+    def __init__(self, res, count=0, shouldStop=False):
+        self.res = res
+        self.shouldStop = shouldStop
+        self.count = count
+
+
+class SingleURIReplacer:
+    def __init__(self, arg):
+        raise NotImplementedError
+
+    def __call__(self, ctx):
+        raise NotImplementedError
+
+
+class CombinedReplacer(SingleURIReplacer):
+    __slots__ = ("children",)
+
+    def __init__(self, children):
+        self.children = children
+
+    def __call__(self, ctx):
+        for r in self.children:
+            r(ctx)
+            if ctx.shouldStop:
+                break
+        return ctx
+
+
+class CombinedReplacerFactory:
+    __slots__ = ("args2Ctors",)
+
+    def __init__(self, args2Ctors):
+        self.args2Ctors = args2Ctors
+
+    def _gen_replacers(self, kwargs):
+        for k, v in kwargs.items():
+            c = self.args2Ctors.get(k, None)
+            if c:
+                yield c(v)
+
+    def __call__(self, **kwargs):
+        return CombinedReplacer(tuple(self._gen_replacers(kwargs)))
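How CombinedReplacerFactory routes constructor kwargs to child replacers,
shown with a toy replacer (illustrative only, not part of the patch):

    from https_everywhere.core import (
        CombinedReplacerFactory,
        ReplaceContext,
        SingleURIReplacer,
    )

    class UpperReplacer(SingleURIReplacer):  # toy child replacer
        def __init__(self, arg):
            pass

        def __call__(self, ctx):
            ctx.res = ctx.res.upper()
            ctx.count += 1

    facs = CombinedReplacerFactory({"upper": UpperReplacer})
    combined = facs(upper=True)  # the kwarg name selects the constructor
    ctx = ReplaceContext("http://example.com/")
    combined(ctx)  # each child runs in order unless ctx.shouldStop is set
    print(ctx.res, ctx.count)  # HTTP://EXAMPLE.COM/ 1
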
diff --git a/https_everywhere/replacers/HEReplacer.py b/https_everywhere/replacers/HEReplacer.py
new file mode 100644
index 0000000..f31dbf7
--- /dev/null
+++ b/https_everywhere/replacers/HEReplacer.py
@@ -0,0 +1,19 @@
+from .. import _rules
+from .._rules import _get_rulesets, https_url_rewrite
+from ..core import SingleURIReplacer
+
+
+class HEReplacer(SingleURIReplacer):
+    __slots__ = ("rulesets",)
+
+    def __init__(self, rulesets):
+        if rulesets is None:
+            _get_rulesets()  # loads the rulesets into _rules._DATA
+            rulesets = _rules._DATA
+        self.rulesets = rulesets
+
+    def __call__(self, ctx):
+        prevRes = ctx.res
+        ctx.res = https_url_rewrite(ctx.res, self.rulesets)
+        if prevRes != ctx.res:
+            ctx.count += 1
diff --git a/https_everywhere/replacers/HSTSPreloadReplacer.py b/https_everywhere/replacers/HSTSPreloadReplacer.py
new file mode 100644
index 0000000..74667a2
--- /dev/null
+++ b/https_everywhere/replacers/HSTSPreloadReplacer.py
@@ -0,0 +1,31 @@
+from urllib3.util.url import parse_url
+
+from .._chrome_preload_hsts import _preload_including_subdomains as _get_preload_chrome
+from .._mozilla_preload_hsts import _preload_remove_negative as _get_preload_mozilla
+from .._util import _check_in
+from ..core import SingleURIReplacer
+
+
+def apply_HSTS_preload(url, domains):
+    p = parse_url(url)
+    if p.scheme == "http" and p.host is not None and _check_in(domains, p.host):
+        return "https:" + url[len(p.scheme) + 1:]
+    return url
+
+
+class HSTSPreloadReplacer(SingleURIReplacer):
+    __slots__ = ("preloads",)
+
+    def __init__(self, preloads):
+        if preloads is None:
+            preloads = _get_preload_mozilla() | _get_preload_chrome()
+        self.preloads = preloads
+
+    def __call__(self, ctx):
+        prevRes = ctx.res
+        ctx.res = apply_HSTS_preload(ctx.res, self.preloads)
+        if prevRes != ctx.res:
+            ctx.count += 1
diff --git a/https_everywhere/replacers/__init__.py b/https_everywhere/replacers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/setup.py b/setup.py
index 8b4e516..49c71d5 100755
--- a/setup.py
+++ b/setup.py
@@ -61,4 +61,7 @@
     classifiers=classifiers.splitlines(),
     tests_require=["unittest-expander", "lxml", "tldextract", "regex"],
     # lxml is optional, needed for testing upstream rules
+    entry_points={
+        "console_scripts": ["pyhttpeverywhere = https_everywhere.__main__:CLI.run"]
+    },
 )
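Example invocations of the new subcommand (the `pyhttpeverywhere` script name
comes from the entry point above; the paths are hypothetical):

    pyhttpeverywhere bulk -t src/ docs/index.html
    git ls-files | pyhttpeverywhere bulk -
    python -m https_everywhere bulk --no-skip-binary some-dir/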