mirror of
https://github.com/Cian-H/invenio-theme-iform.git
synced 2025-12-22 12:41:57 +00:00
Improved check_version_increment script
This commit is contained in:
@@ -1,29 +1,102 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
"""A Script for checking that the version number has been incremented between pushes."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tomllib
|
import tomllib
|
||||||
|
|
||||||
|
ALLOWED_EXECUTABLES: list[str] = [
|
||||||
|
"git",
|
||||||
|
"uv",
|
||||||
|
]
|
||||||
|
|
||||||
def run_command(command):
|
|
||||||
return subprocess.run(
|
def validate_command(command: str) -> list[str]:
|
||||||
command, shell=True, capture_output=True, text=True, check=True
|
"""Validate and guard command calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command (str): the command to be validated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: the validated command
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: if the command was not found
|
||||||
|
|
||||||
|
"""
|
||||||
|
cmd: list[str] = command.split().copy()
|
||||||
|
if cmd[0] not in ALLOWED_EXECUTABLES:
|
||||||
|
msg = f"Command {cmd} is not allowed!"
|
||||||
|
raise PermissionError(msg)
|
||||||
|
call = shutil.which(cmd[0])
|
||||||
|
if not call:
|
||||||
|
msg = f"Command {call} not found!"
|
||||||
|
raise FileNotFoundError(msg)
|
||||||
|
cmd[0] = call
|
||||||
|
if cmd[0] == "uv" and cmd[1] == "run":
|
||||||
|
cmd = [cmd[0], cmd[1], *validate_command("".join(cmd[2:]))]
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(command: str) -> str:
|
||||||
|
"""Run a command and get its output as a string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command (str): The command to run.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output returned on stdout.
|
||||||
|
|
||||||
|
"""
|
||||||
|
cmd = validate_command(command)
|
||||||
|
return subprocess.run( # noqa: S603
|
||||||
|
cmd,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
check=True,
|
||||||
).stdout.strip()
|
).stdout.strip()
|
||||||
|
|
||||||
|
|
||||||
def get_version_file_from_pyproject():
|
def get_version_file_from_pyproject() -> str:
|
||||||
|
"""Get the path to the version file from the pyproject file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: the path to the project version file
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: the pyproject file is not found
|
||||||
|
KeyError: the pyproject file does not reference a version file
|
||||||
|
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
with Path("pyproject.toml").open("rb") as f:
|
with Path("pyproject.toml").open("rb") as f:
|
||||||
return tomllib.load(f)["tool"]["hatch"]["version"]["path"]
|
return tomllib.load(f)["tool"]["hatch"]["version"]["path"]
|
||||||
except FileNotFoundError:
|
except FileNotFoundError as e:
|
||||||
raise FileNotFoundError("Project has no pyproject.toml file!")
|
msg = "Project has no pyproject.toml file!"
|
||||||
except KeyError:
|
raise FileNotFoundError(msg) from e
|
||||||
raise KeyError("Attribute `tool.hatch.version.path` not found in pyproject.toml")
|
except KeyError as e:
|
||||||
|
msg = "Attribute `tool.hatch.version.path` not found in pyproject.toml"
|
||||||
|
raise KeyError(msg) from e
|
||||||
|
|
||||||
|
|
||||||
def get_remote_version(remote, branch, version_file):
|
def get_remote_version(remote: str, branch: str, version_file: str) -> str:
|
||||||
|
"""Get the version from the repository remote.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
remote (str): The remote to fetch the version string from.
|
||||||
|
branch (str): The branch to fetch the version string from.
|
||||||
|
version_file (str): The file to fetch the version string from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The version string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AttributeError: NO `__version__` attribute was found in `version_file`.
|
||||||
|
|
||||||
|
"""
|
||||||
remote_file = f"{remote}/{branch}:{version_file}"
|
remote_file = f"{remote}/{branch}:{version_file}"
|
||||||
match = re.search(
|
match = re.search(
|
||||||
r"__version__\s*=\s*['\"]([^'\"]+)['\"]",
|
r"__version__\s*=\s*['\"]([^'\"]+)['\"]",
|
||||||
@@ -31,11 +104,12 @@ def get_remote_version(remote, branch, version_file):
|
|||||||
)
|
)
|
||||||
if match:
|
if match:
|
||||||
return match.group(1)
|
return match.group(1)
|
||||||
else:
|
msg = f"No `__version__` attribute found in {remote_file}"
|
||||||
raise AttributeError(f"No `__version__` attribute found in {remote_file}")
|
raise AttributeError(msg)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
|
"""Entrypoint for this script."""
|
||||||
version_file = get_version_file_from_pyproject()
|
version_file = get_version_file_from_pyproject()
|
||||||
branch = run_command("git rev-parse --abbrev-ref HEAD")
|
branch = run_command("git rev-parse --abbrev-ref HEAD")
|
||||||
remote = sys.argv[1] if len(sys.argv) > 1 else "origin"
|
remote = sys.argv[1] if len(sys.argv) > 1 else "origin"
|
||||||
|
|||||||
Reference in New Issue
Block a user