From 3475f6e7e550a6935616dfb9364fb1d3e7f77a0e Mon Sep 17 00:00:00 2001 From: Cian Hughes Date: Tue, 20 May 2025 11:19:56 +0100 Subject: [PATCH] Improved `check_version_increment` script --- .../hooks/pre-push/check_version_increment.py | 98 ++++++++++++++++--- 1 file changed, 86 insertions(+), 12 deletions(-) diff --git a/scripts/hooks/pre-push/check_version_increment.py b/scripts/hooks/pre-push/check_version_increment.py index a7295a2..f34c322 100755 --- a/scripts/hooks/pre-push/check_version_increment.py +++ b/scripts/hooks/pre-push/check_version_increment.py @@ -1,29 +1,102 @@ #!/usr/bin/env python +"""A Script for checking that the version number has been incremented between pushes.""" from pathlib import Path import re +import shutil import subprocess import sys import tomllib +ALLOWED_EXECUTABLES: list[str] = [ + "git", + "uv", +] -def run_command(command): - return subprocess.run( - command, shell=True, capture_output=True, text=True, check=True + +def validate_command(command: str) -> list[str]: + """Validate and guard command calls. + + Args: + command (str): the command to be validated. + + Returns: + str: the validated command + + Raises: + FileNotFoundError: if the command was not found + + """ + cmd: list[str] = command.split().copy() + if cmd[0] not in ALLOWED_EXECUTABLES: + msg = f"Command {cmd} is not allowed!" + raise PermissionError(msg) + call = shutil.which(cmd[0]) + if not call: + msg = f"Command {call} not found!" + raise FileNotFoundError(msg) + cmd[0] = call + if cmd[0] == "uv" and cmd[1] == "run": + cmd = [cmd[0], cmd[1], *validate_command("".join(cmd[2:]))] + return cmd + + +def run_command(command: str) -> str: + """Run a command and get its output as a string. + + Args: + command (str): The command to run. + + Returns: + The output returned on stdout. + + """ + cmd = validate_command(command) + return subprocess.run( # noqa: S603 + cmd, + capture_output=True, + text=True, + check=True, ).stdout.strip() -def get_version_file_from_pyproject(): +def get_version_file_from_pyproject() -> str: + """Get the path to the version file from the pyproject file. + + Returns: + str: the path to the project version file + + Raises: + FileNotFoundError: the pyproject file is not found + KeyError: the pyproject file does not reference a version file + + """ try: with Path("pyproject.toml").open("rb") as f: return tomllib.load(f)["tool"]["hatch"]["version"]["path"] - except FileNotFoundError: - raise FileNotFoundError("Project has no pyproject.toml file!") - except KeyError: - raise KeyError("Attribute `tool.hatch.version.path` not found in pyproject.toml") + except FileNotFoundError as e: + msg = "Project has no pyproject.toml file!" + raise FileNotFoundError(msg) from e + except KeyError as e: + msg = "Attribute `tool.hatch.version.path` not found in pyproject.toml" + raise KeyError(msg) from e -def get_remote_version(remote, branch, version_file): +def get_remote_version(remote: str, branch: str, version_file: str) -> str: + """Get the version from the repository remote. + + Args: + remote (str): The remote to fetch the version string from. + branch (str): The branch to fetch the version string from. + version_file (str): The file to fetch the version string from. + + Returns: + str: The version string. + + Raises: + AttributeError: NO `__version__` attribute was found in `version_file`. + + """ remote_file = f"{remote}/{branch}:{version_file}" match = re.search( r"__version__\s*=\s*['\"]([^'\"]+)['\"]", @@ -31,11 +104,12 @@ def get_remote_version(remote, branch, version_file): ) if match: return match.group(1) - else: - raise AttributeError(f"No `__version__` attribute found in {remote_file}") + msg = f"No `__version__` attribute found in {remote_file}" + raise AttributeError(msg) -def main(): +def main() -> None: + """Entrypoint for this script.""" version_file = get_version_file_from_pyproject() branch = run_command("git rev-parse --abbrev-ref HEAD") remote = sys.argv[1] if len(sys.argv) > 1 else "origin"