Files
invenio-config-iform/scripts/hooks/pre-push/check_version_increment.py

127 lines
3.4 KiB
Python
Executable File

#!/usr/bin/env python
"""A Script for checking that the version number has been incremented between pushes."""
from pathlib import Path
import re
import shutil
import subprocess
import sys
import tomllib
ALLOWED_EXECUTABLES: list[str] = [
"git",
"uv",
]
def validate_command(command: str) -> list[str]:
"""Validate and guard command calls.
Args:
command (str): the command to be validated.
Returns:
str: the validated command
Raises:
FileNotFoundError: if the command was not found
"""
cmd: list[str] = command.split().copy()
if cmd[0] not in ALLOWED_EXECUTABLES:
msg = f"Command {cmd} is not allowed!"
raise PermissionError(msg)
call = shutil.which(cmd[0])
if not call:
msg = f"Command {call} not found!"
raise FileNotFoundError(msg)
cmd[0] = call
if cmd[0] == "uv" and cmd[1] == "run":
cmd = [cmd[0], cmd[1], *validate_command("".join(cmd[2:]))]
return cmd
def run_command(command: str) -> str:
"""Run a command and get its output as a string.
Args:
command (str): The command to run.
Returns:
The output returned on stdout.
"""
cmd = validate_command(command)
return subprocess.run( # noqa: S603
cmd,
capture_output=True,
text=True,
check=True,
).stdout.strip()
def get_version_file_from_pyproject() -> str:
"""Get the path to the version file from the pyproject file.
Returns:
str: the path to the project version file
Raises:
FileNotFoundError: the pyproject file is not found
KeyError: the pyproject file does not reference a version file
"""
try:
with Path("pyproject.toml").open("rb") as f:
return tomllib.load(f)["tool"]["hatch"]["version"]["path"]
except FileNotFoundError as e:
msg = "Project has no pyproject.toml file!"
raise FileNotFoundError(msg) from e
except KeyError as e:
msg = "Attribute `tool.hatch.version.path` not found in pyproject.toml"
raise KeyError(msg) from e
def get_remote_version(remote: str, branch: str, version_file: str) -> str:
"""Get the version from the repository remote.
Args:
remote (str): The remote to fetch the version string from.
branch (str): The branch to fetch the version string from.
version_file (str): The file to fetch the version string from.
Returns:
str: The version string.
Raises:
AttributeError: NO `__version__` attribute was found in `version_file`.
"""
remote_file = f"{remote}/{branch}:{version_file}"
match = re.search(
r"__version__\s*=\s*['\"]([^'\"]+)['\"]",
run_command(f"git show {remote_file}"),
)
if match:
return match.group(1)
msg = f"No `__version__` attribute found in {remote_file}"
raise AttributeError(msg)
def main() -> None:
"""Entrypoint for this script."""
version_file = get_version_file_from_pyproject()
branch = run_command("git rev-parse --abbrev-ref HEAD")
remote = sys.argv[1] if len(sys.argv) > 1 else "origin"
current_version = run_command("uv run hatch version")
remote_version = get_remote_version(remote, branch, version_file)
assert current_version != remote_version, (
"Version has not been incremented! Please update the version before pushing."
)
if __name__ == "__main__":
main()