"""CLI interface for pytokens.""" from __future__ import annotations import argparse import io import os.path import tokenize from typing import Iterable, NamedTuple import warnings import pytokens class CLIArgs: filepath: str validate: bool issue_128233_handling: bool def cli(argv: list[str] | None = None) -> int: """CLI interface.""" parser = argparse.ArgumentParser() parser.add_argument("filepath") parser.add_argument( "--no-128233-handling", dest="issue_128233_handling", action="store_false", ) parser.add_argument("--validate", action="store_true") args = parser.parse_args(argv, namespace=CLIArgs()) if os.path.isdir(args.filepath): files = find_all_python_files(args.filepath) verbose = False else: files = [args.filepath] verbose = True for filepath in sorted(files): with open(filepath, "rb") as file: try: encoding, read_bytes = tokenize.detect_encoding(file.readline) except SyntaxError: if args.validate: # Broken `# coding` comment, tokenizer bails, skip file print("\033[1;33mS\033[0m", end="", flush=True) continue raise source = b"".join(read_bytes) + file.read() if args.validate: validate( filepath, source, encoding, verbose=verbose, issue_128233_handling=args.issue_128233_handling, ) else: source_str = source.decode(encoding) for token in pytokens.tokenize( source_str, issue_128233_handling=args.issue_128233_handling, ): token_source = source_str[token.start_index : token.end_index] print(repr(token_source), token) return 0 class TokenTuple(NamedTuple): type: str start: tuple[int, int] end: tuple[int, int] def validate( filepath: str, source: bytes, encoding: str, *, issue_128233_handling: bool, verbose: bool = True, ) -> None: """Validate the source code.""" warnings.simplefilter("ignore") # Ensure all line endings have newline as a valid index if len(source) == 0 or source[-1:] != b"\n": source = source + b"\n" # Same as .splitlines(keepends=True), but doesn't split on linefeeds i.e. \x0c sourcelines = [line + b"\n" for line in source.split(b"\n")] # For that last newline token that exists on an imaginary line sometimes sourcelines.append(b"\n") source_file = io.BytesIO(source) builtin_tokens = tokenize.tokenize(source_file.readline) # drop the encoding token next(builtin_tokens) try: expected_tokens_unprocessed = [ TokenTuple(tokenize.tok_name[token.type], token.start, token.end) for token in builtin_tokens ] except tokenize.TokenError: print("\033[1;33mS\033[0m", end="", flush=True) return expected_tokens = [expected_tokens_unprocessed[0]] for index, token in enumerate(expected_tokens_unprocessed[1:], start=1): last_token = expected_tokens[-1] current_token = token # Merge consecutive FSTRING_MIDDLE tokens. it's weird cpython has it like that. if current_token.type == last_token.type == "FSTRING_MIDDLE": expected_tokens.pop() current_token = TokenTuple( current_token.type, last_token.start, current_token.end, ) if index + 1 < len(expected_tokens_unprocessed): # When an FSTRING_MIDDLE ends with a `{{{` like f'x{{{1}', Python eats # the last { char as well as its end index, so we get a `x{` token # instead of the expected `x{{` token. This fixes that case. Pretty # much always there should be no gap between an fstring-middle ending # and the { op after it. 
            # Same deal for `}}}"`
            next_token = expected_tokens_unprocessed[index + 1]
            # NOTE: `and` binds tighter than `or` here, so the same-line and
            # gap checks only guard the FSTRING_END branch.
            if (
                (current_token.type == "FSTRING_MIDDLE" and next_token.type == "OP")
                or (
                    current_token.type == "FSTRING_MIDDLE"
                    and next_token.type == "FSTRING_END"
                )
                and next_token.start[0] == current_token.end[0]
                and next_token.start[1] > current_token.end[1]
            ):
                expected_tokens.append(
                    TokenTuple(
                        current_token.type,
                        current_token.start,
                        next_token.start,
                    )
                )
                continue

        expected_tokens.append(current_token)

    source_string = source.decode(encoding)
    our_tokens = (
        TokenTuple(
            token.type.to_python_token(),
            (token.start_line, token.start_col),
            (token.end_line, token.end_col),
        )
        for token in pytokens.tokenize(
            source_string, issue_128233_handling=issue_128233_handling
        )
        if token.type != pytokens.TokenType.whitespace
    )
    for builtin_token, our_token in zip(expected_tokens, our_tokens, strict=True):
        mismatch = builtin_token != our_token
        if mismatch or verbose:
            print("EXPECTED", builtin_token)
            print("---- GOT", our_token)

        if mismatch:
            print("Filepath:", filepath)
            print("\033[1;31mF\033[0m", end="", flush=True)
            # raise AssertionError("Tokens do not match")
            return

    print("\033[1;32m.\033[0m", end="", flush=True)


def find_all_python_files(directory: str) -> Iterable[str]:
    """Recursively find all Python files in the given directory."""
    python_files = set()
    for root, _, files in os.walk(directory, followlinks=False):
        for file in files:
            if file.endswith(".py"):
                python_files.add(os.path.join(root, file))
    return python_files
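

# A minimal sketch of a direct entry point, assuming this module is meant to be
# run as a script. The real package may instead expose `cli` through a
# console-script entry point in its packaging metadata; the invocation in the
# comment below is an illustration, not a documented interface.
if __name__ == "__main__":
    # e.g. `python cli.py somefile.py --validate`
    raise SystemExit(cli())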