Skip to content

Commit

Permalink
Merge pull request #19 from mraspaud/feature-multiple-threads
Browse files Browse the repository at this point in the history
Add multiple threads capability
  • Loading branch information
mraspaud authored Nov 13, 2024
2 parents 815d5c5 + a6de434 commit d5e4e78
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 9 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ The configuration file is made of three sections.

### `script`

Full path script to run
A dictionary with:
- `command` Full path script to run
- optionally `workers` The number of workers to use for parallel processing of messages. Defaults to 1.

### `subscriber_config`

Expand Down
41 changes: 33 additions & 8 deletions pytroll_runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
static_metadata:
sensor: thermometer
topic: /hi/there
script: /tmp/pytest-of-a001673/pytest-169/test_fake_publisher0/myscript_bla.sh
script:
command: /tmp/pytest-of-a001673/pytest-169/test_fake_publisher0/myscript_bla.sh
workers: 4
subscriber_config:
addresses:
- ipc://bla
Expand All @@ -26,7 +28,9 @@
import os
import re
from contextlib import closing, suppress
from functools import partial
from glob import glob
from multiprocessing.pool import ThreadPool
from subprocess import PIPE, Popen

import yaml
Expand Down Expand Up @@ -137,24 +141,45 @@ def run_from_new_subscriber(command, subscriber_settings):

def run_on_messages(command, messages):
"""Run the command on files from messages."""
try:
num_workers = command.get("workers", 1)
except AttributeError:
num_workers = 1
pool = ThreadPool(num_workers)
run_command_on_message = partial(run_on_single_message, command)

yield from pool.imap_unordered(run_command_on_message, select_messages(messages))


def select_messages(messages):
"""Select only valid messages."""
accepted_message_types = ["file", "dataset"]
for message in messages:
if message.type not in accepted_message_types:
continue
try: # file
files = [message.data["uri"]]
except KeyError: # dataset
files = []
files.extend(info["uri"] for info in message.data["dataset"])
yield run_on_files(command, files), message.data
yield message


def run_on_single_message(command, message):
"""Run the command on files from message."""
try: # file
files = [message.data["uri"]]
except KeyError: # dataset
files = []
files.extend(info["uri"] for info in message.data["dataset"])
return run_on_files(command, files), message.data


def run_on_files(command, files):
"""Run the command of files."""
if not files:
return
logger.info(f"Start running command {command} on files {files}")
process = Popen([*os.fspath(command).split(), *files], stdout=PIPE) # noqa: S603
try:
command_to_call = command["command"]
except TypeError:
command_to_call = command
process = Popen([*os.fspath(command_to_call).split(), *files], stdout=PIPE) # noqa: S603
out, _ = process.communicate()
logger.debug(f"After having run the script: {out}")
return out
Expand Down
35 changes: 35 additions & 0 deletions pytroll_runner/tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
script_bla = """#!/bin/bash
for file in $*; do
cp "$file" "$file.bla"
echo "Written output file : $file.bla"
done
"""

Expand Down Expand Up @@ -478,6 +479,40 @@ def test_main_parse_log_configfile(config_file_aws,log_config_file):

assert isinstance(logging.getLogger("").handlers[0], logging.StreamHandler)

@pytest.fixture
def ten_files_to_glob(tmp_path):
"""Create multiple files to glob."""
n_files = 10
some_files = [f"file{n}" for n in range(n_files)]
for filename in some_files:
with open(tmp_path / filename, "w") as fd:
fd.write("hi")
return some_files

def test_run_and_publish_with_command_subitem_and_thread_number(tmp_path, command_bla, ten_files_to_glob):
"""Test run and publish."""
sub_config = dict(nameserver=False, addresses=["ipc://bla"])
pub_config = dict(publisher_settings=dict(nameservers=False, port=1979),
output_files_log_regex="Written output file : (.*.bla)",
topic="/hi/there")
command_path = os.fspath(command_bla)
test_config = dict(subscriber_config=sub_config,
script=dict(command=command_path, workers=4),
publisher_config=pub_config)
yaml_file = tmp_path / "config.yaml"
with open(yaml_file, "w") as fd:
fd.write(yaml.dump(test_config))

some_files = ten_files_to_glob
datas = [{"uri": os.fspath(tmp_path / f), "uid": f} for f in some_files]
messages = [Message("some_topic", "file", data=data) for data in datas]

with patched_subscriber_recv(messages):
with patched_publisher() as published_messages:
run_and_publish(yaml_file)
assert len(published_messages) == 10
res_files = [Message(rawstr=msg).data["uid"] for msg in published_messages]
assert res_files != sorted(res_files)

def test_run_and_publish_from_message_file(tmp_path, config_file_aws):
"""Test run and publish."""
Expand Down

0 comments on commit d5e4e78

Please sign in to comment.