Skip to content

Commit

Permalink
Merge pull request #19 from d-krupke/development
Browse files Browse the repository at this point in the history
v0.9.0
  • Loading branch information
d-krupke authored Feb 24, 2024
2 parents 4bb35ea + bbbd559 commit 966fa7d
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 79 deletions.
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ A simple script could look like this:
# compatible with any environment.
# You can enforce slurm with `slurminade.set_dispatcher(slurminade.SlurmDispatcher())`
@slurminade.node_setup
def setup():
print("I will run automatically on every slurm node at the beginning!")
# use this decorator to make a function distributable with slurm
@slurminade.slurmify(
Expand Down Expand Up @@ -353,6 +357,7 @@ The project is reasonably easy:
Changes
-------

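- 0.9.0: Added the ``node_setup`` decorator, switched ``slurminade.execute`` to a ``click``-based command line, and added an early check that dispatched functions are available from the entry point.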
- 0.8.1: Bugfix and automatic detection of wrong usage when using ``Batch`` with ``wait_for``.
- 0.8.0: Added extensive logging and improved typing.
- 0.7.0: Warning if a Batch is flushed multiple times, as we noticed this to be a common indentation error.
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ where = ["src"]

[project]
name = "slurminade"
dynamic = ["version"]
version = "1.0.0"
authors = [
{ name = "TU Braunschweig, IBR, Algorithms Group (Dominik Krupke)", email = "krupke@ibr.cs.tu-bs.de" },
]
Expand All @@ -21,14 +21,14 @@ classifiers = [
]
keywords=["slurm"]
dependencies = [
"simple_slurm>=0.2.6"
"simple_slurm>=0.2.6",
"click",
]

[project.urls]
Homepage = "https://github.com/d-krupke/slurminade"
Issues = "https://github.com/d-krupke/slurminade/issues"

[tool.setuptools_scm]

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down
7 changes: 5 additions & 2 deletions src/slurminade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def clean_up():
SubprocessDispatcher,
)
from .function_map import set_entry_point
from .node_setup import node_setup

__all__ = [
"slurmify",
Expand All @@ -90,10 +91,12 @@ def clean_up():
"TestDispatcher",
"SubprocessDispatcher",
"set_entry_point",
"node_setup",
]

# set default logging
import logging
import sys

logging.getLogger("slurminade").setLevel(logging.INFO)
logging.getLogger("slurminade").addHandler(logging.StreamHandler())
# Set up the root logger to print to stdout by default
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
2 changes: 1 addition & 1 deletion src/slurminade/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type:
print("Aborted due to exception.")
logging.getLogger("slurminade").error("Aborted due to exception.")
return
self.flush()
set_dispatcher(self.subdispatcher)
Expand Down
21 changes: 15 additions & 6 deletions src/slurminade/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,19 +367,23 @@ def create_slurminade_command(
:param max_arg_length: The maximum allowed length of a command line argument.
:returns: A string representing the command to be executed in the terminal.
"""
command = f"{sys.executable} -m slurminade.execute {shlex.quote(get_entry_point())}"
command = f"{sys.executable} -m slurminade.execute --root {shlex.quote(get_entry_point())}"

# Serialize function calls as JSON
serialized_calls = json.dumps([f.to_json() for f in funcs])
json_calls = json.dumps([f.to_json() for f in funcs])
serialized_calls = shlex.quote(json_calls)

if len(shlex.quote(serialized_calls)) > max_arg_length:
if len(serialized_calls) > max_arg_length:
# The argument is too long, create temporary file for the JSON
fd, filename = mkstemp(prefix="slurminade_", suffix=".json", text=True, dir=".")
logging.getLogger("slurminade").info(
f"Long function calls. Serializing function calls to temporary file {filename}"
)
with os.fdopen(fd, "w") as f:
f.write(serialized_calls)
command += f" temp {shlex.quote(filename)}"
f.write(json_calls)
command += f" --fromfile {filename}"
else:
command += f" arg {shlex.quote(serialized_calls)}"
command += f" --calls {serialized_calls}"
return command


Expand Down Expand Up @@ -427,6 +431,11 @@ def dispatch(
:param options: The slurm options to be used.
:return: The job id.
"""
funcs = list(funcs) if not isinstance(funcs, FunctionCall) else [funcs]
for func in funcs:
if not FunctionMap.check_id(func.func_id):
msg = f"Function '{func.func_id}' cannot be called from the given entry point."
raise KeyError(msg)
return get_dispatcher()(funcs, options)


Expand Down
94 changes: 50 additions & 44 deletions src/slurminade/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,69 @@
This module provides the starting point for the slurm node. You do not have to call
anything of this file yourself.
"""

import json
import pathlib
import sys
import logging
from pathlib import Path

import click

from .function import SlurmFunction
from .function import FunctionMap, SlurmFunction
from .function_map import set_entry_point
from .guard import prevent_distribution


def parse_args():
    """
    Parse the command line of the process started on the slurm node.

    ``sys.argv`` is expected to be ``[prog, batch_file, mode, payload]``:

    - ``batch_file`` is the file with the code (function definitions).
    - ``mode`` is ``"arg"`` if ``payload`` is the JSON-encoded list of function
      calls, or ``"temp"`` if ``payload`` is the path of a file containing that
      JSON. In ``"temp"`` mode the file is deleted after it has been read.

    :returns: A tuple ``(batch_file_path, function_calls)``.
    :raises RuntimeError: If the batch file or the temporary file does not exist,
        the mode is unknown, or the payload does not decode to a list.
    """
    batch_file_path = pathlib.Path(
        sys.argv[1]
    )  # the file with the code (function definition)
    if not batch_file_path.exists():
        msg = "Batch file does not exist.\n"
        msg += f" File: {batch_file_path}\n"
        msg += "This should not happen. Please report this bug."
        raise RuntimeError(msg)
    # determine whether function calls are provided as an argument or in a temp file.
    mode = sys.argv[2]
    if mode == "arg":
        function_calls = json.loads(sys.argv[3])
    elif mode == "temp":
        tmp_file_path = pathlib.Path(sys.argv[3])
        if not tmp_file_path.exists():
            msg = "Using temporary file for passing function arguments, but file does not exist.\n"
            msg += f" File: {tmp_file_path}\n"
            msg += "This should not happen. Please report this bug."
            raise RuntimeError(msg)
        with open(tmp_file_path) as f:
            function_calls = json.load(f)
        tmp_file_path.unlink()  # delete the temp file
    else:
        msg = "Unknown function call mode. Expected 'arg' or 'temp'.\n"
        msg += f" Got: {mode}\n"
        msg += "This should not happen. Please report this bug."
        raise RuntimeError(msg)
    # An explicit check instead of `assert`: asserts are removed when Python
    # runs with -O, which would let a malformed payload slip through.
    if not isinstance(function_calls, list):
        msg = "Expected a list of dicts.\n"
        msg += f" Got: {type(function_calls).__name__}\n"
        msg += "This should not happen. Please report this bug."
        raise RuntimeError(msg)
    return batch_file_path, function_calls


def main():
@click.command()
@click.option(
"--root",
type=click.Path(exists=True),
help="The root file of the task.",
required=True,
)
@click.option("--calls", type=str, help="The function calls.", required=False)
@click.option(
"--fromfile",
type=click.Path(exists=True),
help="The file to read the function calls from.",
required=False,
)
@click.option(
"--listfuncs",
help="List all available functions.",
default=False,
is_flag=True,
required=False,
)
def main(root, calls, fromfile, listfuncs):
prevent_distribution() # make sure, the code on the node does not distribute itself.
batch_file, function_calls = parse_args()

set_entry_point(batch_file)
with open(batch_file) as f:
set_entry_point(root)
with open(root) as f:
code = "".join(f.readlines())

# Workaround as otherwise __name__ is not defined
global __name__
__name__ = None

glob = dict(globals())
glob["__file__"] = batch_file
glob["__file__"] = root
glob["__name__"] = None
exec(code, glob)

if listfuncs:
print(json.dumps(FunctionMap.get_all_ids())) # noqa T201
return
if calls:
function_calls = json.loads(calls)
elif fromfile:
with open(fromfile) as f:
logging.getLogger("slurminade").info(
f"Reading function calls from {fromfile}."
)
function_calls = json.load(f)
Path(fromfile).unlink()
else:
msg = "No function calls provided."
raise ValueError(msg)
if not isinstance(function_calls, list):
msg = "Expected a list of function calls."
raise ValueError(msg)
# Execute the functions
for fc in function_calls:
SlurmFunction.call(fc["func_id"], *fc.get("args", []), **fc.get("kwargs", {}))
Expand Down
25 changes: 25 additions & 0 deletions src/slurminade/function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
"""

import inspect
import json
import os
import pathlib
import subprocess
import sys
import typing


Expand All @@ -21,6 +24,7 @@ class FunctionMap:
# slurminade will set this value in the beginning to reconstruct it.
entry_point: typing.Optional[str] = None
_data = {}
_ids = set()

@staticmethod
def get_id(func: typing.Callable) -> str:
Expand Down Expand Up @@ -92,6 +96,27 @@ def call(
raise KeyError(msg)
return FunctionMap._data[func_id](*args, **kwargs)

@staticmethod
def check_id(func_id: str) -> bool:
    """
    Check whether ``func_id`` is available when loading the entry point.

    Known ids are cached in ``FunctionMap._ids``. On a cache miss, the entry
    point is loaded in a subprocess via ``slurminade.execute --listfuncs``,
    which prints the available function ids as JSON, and the cache is
    replaced by that list.

    :param func_id: The function id to look up.
    :returns: True if the id is among the ids reported for the entry point.
    :raises subprocess.CalledProcessError: If the subprocess fails.
    :raises ValueError: If the subprocess output contains no parsable id list.
    """
    if func_id in FunctionMap._ids:
        return True
    cmd = [
        sys.executable,
        "-m",
        "slurminade.execute",
        "--root",
        get_entry_point(),
        "--listfuncs",
    ]
    out = subprocess.check_output(cmd).decode()
    # The entry point is executed before the ids are printed, so anything it
    # writes to stdout while loading precedes the JSON. The id list is the
    # last thing printed, hence only the last non-empty line is parsed.
    lines = [line for line in out.splitlines() if line.strip()]
    if not lines:
        msg = "Listing the functions of the entry point produced no output."
        raise ValueError(msg)
    FunctionMap._ids = set(json.loads(lines[-1]))
    return func_id in FunctionMap._ids

@staticmethod
def get_all_ids() -> typing.List[str]:
return list(FunctionMap._data.keys())


def set_entry_point(entry_point: typing.Union[str, pathlib.Path]) -> None:
"""
Expand Down
6 changes: 4 additions & 2 deletions src/slurminade/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

# Module-level flag reported by `on_slurm_node`.
# NOTE(review): presumably switched on when the code is executed on a node
# (see `prevent_distribution`) - confirm against the rest of this module.
_exec_flag = False


def on_slurm_node():
    """Return the current value of the module-level ``_exec_flag``."""
    return _exec_flag

def guard_recursive_distribution():
global _exec_flag
if _exec_flag:
if on_slurm_node():
msg = """
You tried to distribute a task recursively. This is not allowed by default,
because it probably indicates a bug in your code. To save you from accidentally
Expand Down
17 changes: 17 additions & 0 deletions src/slurminade/node_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import inspect
import typing
from .guard import on_slurm_node

def node_setup(func: typing.Callable):
    """
    Decorator: Call this function on the node before running any function calls.

    The signature of ``func`` is validated in every environment. If
    ``on_slurm_node()`` is true, ``func`` is additionally called right away,
    i.e. at decoration time. The function itself is returned unchanged.

    :param func: The setup function. It must not declare any parameters.
    :returns: ``func`` itself.
    :raises ValueError: If ``func`` declares any parameters.
    """
    # Validate before branching, so a wrong signature is reported the same way
    # everywhere instead of surfacing as a TypeError from `func()` on the node.
    if inspect.signature(func).parameters:
        msg = "The node setup function must not have any arguments."
        raise ValueError(msg)
    if on_slurm_node():
        func()
    return func
21 changes: 0 additions & 21 deletions tests/test_create_command.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import shlex
import unittest
from pathlib import Path

import slurminade
from slurminade.dispatcher import FunctionCall, create_slurminade_command

test_file_path = Path("./f_test_file.txt")

Expand All @@ -15,25 +13,6 @@ def f(s):


class TestCreateCommand(unittest.TestCase):
def test_create_long_command(self):
slurminade.set_entry_point(__file__)
test_call = FunctionCall(f.func_id, ["." * 100], {})
command = create_slurminade_command([test_call], 100)
args = shlex.split(command)
path = Path(args[-1])
assert args[-2] == "temp"
# check creation of temporary file
assert Path(path).is_file()
if path.exists(): # delete the file
path.unlink()

def test_create_short_command(self):
slurminade.set_entry_point(__file__)
test_call = FunctionCall(f.func_id, [""], {})
command = create_slurminade_command([test_call], 100000)
args = shlex.split(command)
assert args[-2] == "arg"

def test_dispatch_with_temp_file(self):
slurminade.set_entry_point(__file__)
if test_file_path.exists():
Expand Down
13 changes: 13 additions & 0 deletions tests/test_local_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@

import slurminade
from pytest import raises

def test_dispatch_limit_batch():
    """Distributing a function defined inside another function raises ``KeyError``."""

    @slurminade.slurmify()
    def f():
        """Local function; it is not defined at the module level of this file."""

    slurminade.set_entry_point(__file__)

    with raises(KeyError):
        f.distribute()
26 changes: 26 additions & 0 deletions tests/test_node_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import slurminade
from pathlib import Path

test_file_path = Path("./f_test_file.txt")

@slurminade.node_setup
def f():
    """Node setup hook of this test: write the marker text to ``test_file_path``."""
    test_file_path.write_text('node_setup')

@slurminade.slurmify
def nil():
    """Do nothing; ``test_node_setup`` distributes this function."""

def test_node_setup():
    """
    Distribute ``nil`` via a ``SubprocessDispatcher`` and check that the
    ``node_setup`` function ``f`` has written its marker to ``test_file_path``.
    """
    slurminade.set_entry_point(__file__)
    if test_file_path.exists():
        test_file_path.unlink()
    dispatcher = slurminade.SubprocessDispatcher()
    slurminade.set_dispatcher(dispatcher)
    slurminade.set_dispatch_limit(100)
    try:
        nil.distribute()
        with open(test_file_path) as file:
            assert file.readline() == 'node_setup'
    finally:
        # Remove the marker even if dispatching or the assertion failed, so a
        # failing run does not leave the file behind for later tests.
        if test_file_path.exists():
            test_file_path.unlink()

0 comments on commit 966fa7d

Please sign in to comment.