Improved check_version_increment script

This commit is contained in:
2025-05-20 11:19:10 +01:00
parent a0da4cf05f
commit 9b7aa9a217

View File

@@ -1,29 +1,102 @@
#!/usr/bin/env python #!/usr/bin/env python
"""A Script for checking that the version number has been incremented between pushes."""
from pathlib import Path from pathlib import Path
import re import re
import shutil
import subprocess import subprocess
import sys import sys
import tomllib import tomllib
ALLOWED_EXECUTABLES: list[str] = [
"git",
"uv",
]
def run_command(command):
return subprocess.run( def validate_command(command: str) -> list[str]:
command, shell=True, capture_output=True, text=True, check=True """Validate and guard command calls.
Args:
command (str): the command to be validated.
Returns:
str: the validated command
Raises:
FileNotFoundError: if the command was not found
"""
cmd: list[str] = command.split().copy()
if cmd[0] not in ALLOWED_EXECUTABLES:
msg = f"Command {cmd} is not allowed!"
raise PermissionError(msg)
call = shutil.which(cmd[0])
if not call:
msg = f"Command {call} not found!"
raise FileNotFoundError(msg)
cmd[0] = call
if cmd[0] == "uv" and cmd[1] == "run":
cmd = [cmd[0], cmd[1], *validate_command("".join(cmd[2:]))]
return cmd
def run_command(command: str) -> str:
"""Run a command and get its output as a string.
Args:
command (str): The command to run.
Returns:
The output returned on stdout.
"""
cmd = validate_command(command)
return subprocess.run( # noqa: S603
cmd,
capture_output=True,
text=True,
check=True,
).stdout.strip() ).stdout.strip()
def get_version_file_from_pyproject(): def get_version_file_from_pyproject() -> str:
"""Get the path to the version file from the pyproject file.
Returns:
str: the path to the project version file
Raises:
FileNotFoundError: the pyproject file is not found
KeyError: the pyproject file does not reference a version file
"""
try: try:
with Path("pyproject.toml").open("rb") as f: with Path("pyproject.toml").open("rb") as f:
return tomllib.load(f)["tool"]["hatch"]["version"]["path"] return tomllib.load(f)["tool"]["hatch"]["version"]["path"]
except FileNotFoundError: except FileNotFoundError as e:
raise FileNotFoundError("Project has no pyproject.toml file!") msg = "Project has no pyproject.toml file!"
except KeyError: raise FileNotFoundError(msg) from e
raise KeyError("Attribute `tool.hatch.version.path` not found in pyproject.toml") except KeyError as e:
msg = "Attribute `tool.hatch.version.path` not found in pyproject.toml"
raise KeyError(msg) from e
def get_remote_version(remote, branch, version_file): def get_remote_version(remote: str, branch: str, version_file: str) -> str:
"""Get the version from the repository remote.
Args:
remote (str): The remote to fetch the version string from.
branch (str): The branch to fetch the version string from.
version_file (str): The file to fetch the version string from.
Returns:
str: The version string.
Raises:
AttributeError: NO `__version__` attribute was found in `version_file`.
"""
remote_file = f"{remote}/{branch}:{version_file}" remote_file = f"{remote}/{branch}:{version_file}"
match = re.search( match = re.search(
r"__version__\s*=\s*['\"]([^'\"]+)['\"]", r"__version__\s*=\s*['\"]([^'\"]+)['\"]",
@@ -31,11 +104,12 @@ def get_remote_version(remote, branch, version_file):
) )
if match: if match:
return match.group(1) return match.group(1)
else: msg = f"No `__version__` attribute found in {remote_file}"
raise AttributeError(f"No `__version__` attribute found in {remote_file}") raise AttributeError(msg)
def main(): def main() -> None:
"""Entrypoint for this script."""
version_file = get_version_file_from_pyproject() version_file = get_version_file_from_pyproject()
branch = run_command("git rev-parse --abbrev-ref HEAD") branch = run_command("git rev-parse --abbrev-ref HEAD")
remote = sys.argv[1] if len(sys.argv) > 1 else "origin" remote = sys.argv[1] if len(sys.argv) > 1 else "origin"