-
Notifications
You must be signed in to change notification settings - Fork 201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Addition of Amazon ECS job, copied code from luigi.contrib.ecs.py and… #13
base: master
Are you sure you want to change the base?
Changes from 4 commits
18c5536
1ede3cb
150ff2e
bdd50e3
2be8ad3
64f9ff9
a48ff8e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""A job that runs an Amazon ECS task and waits for it to stop.""" | ||
|
||
import time | ||
import logging | ||
from ndscheduler import job | ||
|
||
logger = logging.getLogger(__name__)

# boto3 is treated as an optional dependency: when it is missing the module
# still imports and only a warning is logged.
# NOTE(review): a reviewer on this change asked for the ImportError not to be
# caught, so that a missing boto3 fails hard at import time. The soft failure
# is kept here to preserve the existing import behaviour.
try:
    import boto3
except ImportError:
    # Bind both names so later code fails on a well-defined ``None`` instead
    # of an unrelated NameError.
    boto3 = None
    client = None
    logger.warning('boto3 is not installed. ECSTasks require boto3')
else:
    # Module-wide ECS client shared by every job instance.
    client = boto3.client('ecs')

# Seconds to sleep between two consecutive task status polls.
POLL_TIME = 2
|
||
|
||
class ECSJobException(Exception):
    """Raised when the ECS API reports a failure or an unexpected status."""


class ECSJob(job.JobBase):
    """Job that starts a task on Amazon ECS and blocks until it has stopped.

    The task is identified either by an existing task definition ARN or by a
    task definition dict that is registered first. All ECS calls go through
    the module-level ``client``.
    """

    @classmethod
    def meta_info(cls):
        """Return the metadata dict describing this job and its arguments.

        The ``arguments`` entries are listed in the same order as the
        positional parameters of :meth:`run`.
        """
        return {
            'job_class_string': '%s.%s' % (cls.__module__, cls.__name__),
            'notes': 'This will execute a AWS ECS RunTask!',
            'arguments': [
                {'type': 'string', 'description': 'ECS Cluster to run on'},
                {'type': 'string', 'description': 'task_def_arn'},
                # Unpacked with ** in run(), so it has to be a single dict.
                {'type': 'dict', 'description': 'task_def'},
                # Sent as ``containerOverrides``, i.e. a list of dicts.
                {'type': 'array[dict]',
                 'description': 'Directly corresponds to the '
                                '`overrides` parameter of runTask API'},
            ],
            # Same order as ``arguments``; exactly one of task_def_arn and
            # task_def is given, the other is left null.
            'example_arguments': '["ClusterName", "arn:aws:ecs:<region>'
                                 ':<user_id>:task-definition/<family>:<tag>"'
                                 ', null, null]',
        }

    def _get_task_statuses(self, task_ids):
        """Retrieve task statuses from the ECS API.

        :param task_ids: list of task ARNs to describe.
        :returns: list with the ``lastStatus`` value of each described task
            (RUNNING, PENDING or STOPPED).
        :raises ECSJobException: if the response lists failures or carries a
            non-200 HTTP status code.
        """
        logger.debug('Get status of task_ids: {}'.format(task_ids))
        response = client.describe_tasks(tasks=task_ids, cluster=self.cluster)

        # Error checking
        failures = response.get('failures')
        if failures:
            raise ECSJobException(
                'There were some failures:\n{0}'.format(failures))
        status_code = response['ResponseMetadata']['HTTPStatusCode']
        if status_code != 200:
            msg = 'Task status request received status code {0}:\n{1}'
            raise ECSJobException(msg.format(status_code, response))

        return [task['lastStatus'] for task in response['tasks']]

    def _track_tasks(self, task_ids):
        """Poll the task statuses every POLL_TIME seconds until all STOPPED.

        :param task_ids: list of task ARNs to wait for.
        """
        while True:
            statuses = self._get_task_statuses(task_ids)
            if all(status == 'STOPPED' for status in statuses):
                logger.info('ECS tasks {0} STOPPED'.format(','.join(task_ids)))
                break
            # Log before sleeping so the message reflects fresh statuses.
            logger.debug('ECS task status for tasks {0}: {1}'.format(
                ','.join(task_ids), statuses))
            time.sleep(POLL_TIME)

    @property
    def cluster(self):
        """Name of the ECS cluster, or None (with a warning) if never set."""
        try:
            return self._cluster
        except AttributeError:
            logger.warning('Cluster not set!')
            return None

    @cluster.setter
    def cluster(self, cluster):
        self._cluster = cluster
        logger.debug('Set Cluster: {}'.format(cluster))

    def run(self, cluster, task_def_arn=None, task_def=None, command=None):
        """Run a task on ECS and wait until it has stopped.

        :param cluster: ECS cluster to run the task on.
        :param task_def_arn: ARN of an already registered task definition.
        :param task_def: dict of keyword arguments for
            ``register_task_definition``; registered before the task runs.
        :param command: optional value sent as ``containerOverrides`` in the
            ``overrides`` argument of ``run_task``.
        :raises ValueError: unless exactly one of ``task_def`` and
            ``task_def_arn`` is given.
        :raises ECSJobException: if ECS starts no task, or if polling the
            task status reports a failure.
        """
        self.cluster = cluster
        # Exactly one of the two ways to identify the task must be used.
        if bool(task_def) == bool(task_def_arn):
            raise ValueError('Either (but not both) a task_def (dict) or '
                             'task_def_arn (string) must be assigned')
        if not task_def_arn:
            # Register the task and get assigned taskDefinition ID (arn)
            response = client.register_task_definition(**task_def)
            task_def_arn = response['taskDefinition']['taskDefinitionArn']
        logger.debug('Task Definition ARN: {}'.format(task_def_arn))

        # Submit the task to AWS ECS and get the assigned task IDs
        overrides = {'containerOverrides': command} if command else {}
        response = client.run_task(taskDefinition=task_def_arn,
                                   overrides=overrides, cluster=self.cluster)
        task_ids = [task['taskArn'] for task in response['tasks']]
        if not task_ids:
            # Nothing was started, so there is nothing to poll; surface what
            # ECS reported instead of describing an empty task list.
            raise ECSJobException('ECS did not start any task:\n{0}'.format(
                response.get('failures')))

        # Wait on task completion
        self._track_tasks(task_ids)
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test. Both calls below go to the real ECS API through the
    # module-level boto3 client, so replace the placeholder cluster name and
    # ARN with real values before running this file directly.
    # NOTE(review): boto3 presumably picks up AWS credentials from its
    # standard lookup chain (environment variables, ~/.aws/credentials, or an
    # instance role) - confirm how worker nodes are expected to provide them.
    # The variable is not called ``job`` so that it does not rebind the
    # imported ``ndscheduler.job`` module.
    ecs_job = ECSJob.create_test_instance()

    # 1) Run an already registered task definition, identified by its ARN.
    ecs_job.run('ClusterName', "arn:aws:ecs:<region>:<user_id>:task-"
                "definition/<task_def_name>:<revision_number>")

    # 2) Register a task definition on the fly, then run it.
    ecs_job.run('DataETLCluster', None, {
        'family': 'hello-world',
        'volumes': [],
        'containerDefinitions': [
            {
                'memory': 1,
                'essential': True,
                'name': 'hello-world',
                'image': 'ubuntu',
                'command': ['/bin/echo', 'hello world']
            }
        ]
    })
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Change this comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do people pass AWS credentials? It would be better to add some operational guidelines here, e.g., create a dotfile on the worker nodes for the AWS credentials.