From 7327ac5a0239337401415266c08cb2bc8b841e86 Mon Sep 17 00:00:00 2001 From: Stephen Date: Sat, 7 Jul 2018 00:00:00 +0000 Subject: [PATCH] Added support for selecting cross-validation groups from a metadata key --- kameris/__init__.py | 2 +- kameris/__main__.py | 12 ++++++++++-- kameris/job_steps/classify.py | 21 ++++++++++++++++++--- kameris/schemas/job_options.json | 1 + 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/kameris/__init__.py b/kameris/__init__.py index 410294c..89dce52 100644 --- a/kameris/__init__.py +++ b/kameris/__init__.py @@ -1,3 +1,3 @@ from __future__ import unicode_literals -__version__ = '1.1.2' +__version__ = '1.2.1' diff --git a/kameris/__main__.py b/kameris/__main__.py index ff62356..44889c3 100644 --- a/kameris/__main__.py +++ b/kameris/__main__.py @@ -32,8 +32,6 @@ def main(): subparser.set_defaults(module_name=cmd_settings['module_name']) import importlib - import logging - import sys try: args = parser.parse_args() @@ -42,15 +40,25 @@ def main(): ) run_module.run(args) except Exception as e: + import logging + import sys + import traceback + log = logging.getLogger('kameris') message = 'an unexpected error occurred: {}: {}'.format( type(e).__name__, (e.message if hasattr(e, 'message') else '') or str(e) ) + report_message = ( + 'if you believe this is a bug, please report it at ' + 'https://github.com/stephensolis/kameris/issues and include ALL ' + 'the following text:\n') + ''.join(traceback.format_exc()) if log.handlers: log.error(message) + log.error(report_message) else: print('ERROR ' + message) + print('ERROR ' + report_message) sys.exit(1) diff --git a/kameris/job_steps/classify.py b/kameris/job_steps/classify.py index 94fd480..b9c75f2 100644 --- a/kameris/job_steps/classify.py +++ b/kameris/job_steps/classify.py @@ -124,8 +124,19 @@ def crossvalidation_run(classifier_factory, features, features_mode, # perform validation group splitting validation_count = options['validation_count'] num_points = len(point_classes) - validation_indexes = np.array_split(np.random.permutation(num_points), - validation_count) + if 'validation_split_classes' in options: + val_all_classes = options['validation_split_classes'] + val_split_classes = np.array_split( + np.random.permutation(np.unique(val_all_classes)), validation_count + ) + validation_indexes = [ + np.concatenate([np.where(val_all_classes == split_class)[0] + for split_class in split_classes]) + for split_classes in val_split_classes + ] + else: + validation_indexes = np.array_split(np.random.permutation(num_points), + validation_count) # setup storage for accuracy/stats totals = defaultdict(int) @@ -222,7 +233,11 @@ def run_classify_step(options, exp_options): with open(options['metadata_file'], 'r') as infile: metadata = json.load(infile) point_classes = np.array([x['group'] for x in metadata]) - unique_classes = sorted(set(x['group'] for x in metadata)) + unique_classes = np.unique(point_classes) + if 'validation_split_by' in options: + options['validation_split_classes'] = np.array([ + x[options['validation_split_by']] for x in metadata + ]) # run classifiers and obtain results results = {} diff --git a/kameris/schemas/job_options.json b/kameris/schemas/job_options.json index 752b0d0..bf97b8e 100644 --- a/kameris/schemas/job_options.json +++ b/kameris/schemas/job_options.json @@ -125,6 +125,7 @@ "type": "integer", "minimum": 1 }, + "validation_split_by": {"type": "string"}, "dim_reduce_fraction": { "type": "number", "minimum": 0,