Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
add-fpr-rule: generate migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
mistydemeo committed Feb 25, 2016
1 parent d83c32a commit dd5f68a
Showing 1 changed file with 109 additions and 16 deletions.
125 changes: 109 additions & 16 deletions tools/add-fpr-rule
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from __future__ import print_function
from argparse import ArgumentParser
import datetime
import os
import sys

import jinja2

sys.path.append('/usr/share/archivematica/dashboard')
os.environ["DJANGO_SETTINGS_MODULE"] = "settings.local"

Expand Down Expand Up @@ -75,10 +78,64 @@ NAME_MAP = {
"fptool": (FPTool, FPTOOL_FIELDS),
}


def save_object(obj):
obj.save()
return connection.queries[-1]['sql']
MIGRATION_TEMPLATE = """
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from django.db import models, migrations
import main.models
def forwards_fn(apps, schema_editor):
db_alias = schema_editor.connection.alias
{% for change in changes %}
{{ change.model }} = apps.get_model("fpr", "{{ change.model }}")
{% if change.type == 'create' %}
{{ change.model }}.objects.using(db_alias).create(**{{ change.changes }})
{% else %}
{{ change.model }}.objects.filter({{ change.identity_key }}={{ repr(change.identity_value) }}).update(**{{ change.changes }})
{% endif %}
{% endfor %}
def reverse_fn(apps, schema_editor):
db_alias = schema_editor.connection.alias
{% for change in changes %}
{{ change.model }} = apps.get_model("fpr", "{{ change.model }}")
{% if change.type == 'create' %}
{{ change.model }}.objects.filter(**{{ change.changes }}).delete()
{% else %}
{{ change.model }}.objects.filter({{ change.identity_key }}={{ repr(change.identity_value) }}).update(**{{ change.reverse }})
{% endif %}
{% endfor %}
class Migration(migrations.Migration):
dependencies = []
operations = [
migrations.RunPython(forwards_fn, reverse_fn),
]
"""


def _filter_dict(d):
def test(pair):
k, _ = pair
return not k.startswith('_')
return dict(filter(test, d.items()))


def _convert_datetime_fields(d):
"""
Given a dict, returns a copy where any datetime objects are replaced by
isoformat versions of themselves.
This is needed to make sure that the dictionary can be properly represented
by repr().
"""
new_dict = d.copy()
for k, v in new_dict.items():
if isinstance(v, datetime.datetime):
new_dict[k] = v.isoformat()

return new_dict


def ask_for_field(field):
Expand All @@ -90,36 +147,53 @@ def ask_for_field(field):


def create_format():
changes = []
obj = FormatVersion()

format_description = ask_for_field("format")
try:
format = Format.objects.get(description=format_description)
except Format.DoesNotExist:
format = Format(description=format_description)
sql = save_object(format)
print(sql + ";")
format.save()
changes.append({
'model': 'Format',
'type': 'create',
'changes': _convert_datetime_fields(_filter_dict(format.__dict__))
})

if not format.group:
group_description = ask_for_field("format_group")
try:
group = FormatGroup.objects.get(description=group_description)
except FormatGroup.DoesNotExist:
group = FormatGroup(description=group_description)
sql = save_object(group)
print(sql + ";")
group.save()
changes.append({
'model': 'FormatGroup',
'type': 'create',
'changes': _convert_datetime_fields(_filter_dict(group.__dict__))
})
format.group = group

for field in FORMAT_FIELDS:
val = ask_for_field(field)
if val:
setattr(obj, field, val)

sql = save_object(obj)
print(sql + ";")
obj.save()
changes.append({
'model': 'FormatVersion',
'type': 'create',
'changes': _convert_datetime_fields(_filter_dict(obj.__dict__))
})

return changes


def create_rule(kind, command_file=None):
changes = []

klass, fields = NAME_MAP[kind]
obj = klass()
for field in fields:
Expand All @@ -135,11 +209,28 @@ def create_rule(kind, command_file=None):

if hasattr(obj, "replaces") and obj.replaces:
obj.replaces.enabled = False
sql = save_object(obj.replaces)
print(sql + ";")
obj.replaces.save()
changes.append({
'model': type(obj.replaces).__name__,
'type': 'update',
'identity_key': 'uuid',
'identity_value': obj.replaces.uuid,
'changes': {
'enabled': False,
},
'reverse': {
'enabled': True,
}
})

obj.save()
changes.append({
'model': type(obj).__name__,
'type': 'create',
'changes': _convert_datetime_fields(_filter_dict(obj.__dict__))
})

sql = save_object(obj)
print(sql + ";")
return changes


if __name__ == '__main__':
Expand All @@ -151,8 +242,10 @@ if __name__ == '__main__':

try:
if os.path.basename(sys.argv[0]) == "add-format":
create_format()
changes = create_format()
else:
create_rule(opts.kind, command_file=opts.command)
changes = create_rule(opts.kind, command_file=opts.command)
template = jinja2.Template(MIGRATION_TEMPLATE, trim_blocks=True, lstrip_blocks=True)
print(template.render(changes=changes, repr=repr))
except KeyboardInterrupt:
sys.exit("\nAborting")

0 comments on commit dd5f68a

Please sign in to comment.