Skip to content

Commit

Permalink
Added support for selecting cross-validation groups from a metadata key
Browse files: browse the repository at this point in the history
  • Loading branch information
stephensolis committed Jul 30, 2018
1 parent b077807 commit 7327ac5
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
2 changes: 1 addition & 1 deletion kameris/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import unicode_literals

__version__ = '1.1.2'
__version__ = '1.2.1'
12 changes: 10 additions & 2 deletions kameris/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def main():
subparser.set_defaults(module_name=cmd_settings['module_name'])

import importlib
import logging
import sys

try:
args = parser.parse_args()
Expand All @@ -42,15 +40,25 @@ def main():
)
run_module.run(args)
except Exception as e:
import logging
import sys
import traceback

log = logging.getLogger('kameris')
message = 'an unexpected error occurred: {}: {}'.format(
type(e).__name__,
(e.message if hasattr(e, 'message') else '') or str(e)
)
report_message = (
'if you believe this is a bug, please report it at '
'https://github.com/stephensolis/kameris/issues and include ALL '
'the following text:\n') + ''.join(traceback.format_exc())
if log.handlers:
log.error(message)
log.error(report_message)
else:
print('ERROR ' + message)
print('ERROR ' + report_message)
sys.exit(1)


Expand Down
21 changes: 18 additions & 3 deletions kameris/job_steps/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,19 @@ def crossvalidation_run(classifier_factory, features, features_mode,
# perform validation group splitting
validation_count = options['validation_count']
num_points = len(point_classes)
validation_indexes = np.array_split(np.random.permutation(num_points),
validation_count)
if 'validation_split_classes' in options:
val_all_classes = options['validation_split_classes']
val_split_classes = np.array_split(
np.random.permutation(np.unique(val_all_classes)), validation_count
)
validation_indexes = [
np.concatenate([np.where(val_all_classes == split_class)[0]
for split_class in split_classes])
for split_classes in val_split_classes
]
else:
validation_indexes = np.array_split(np.random.permutation(num_points),
validation_count)

# setup storage for accuracy/stats
totals = defaultdict(int)
Expand Down Expand Up @@ -222,7 +233,11 @@ def run_classify_step(options, exp_options):
with open(options['metadata_file'], 'r') as infile:
metadata = json.load(infile)
point_classes = np.array([x['group'] for x in metadata])
unique_classes = sorted(set(x['group'] for x in metadata))
unique_classes = np.unique(point_classes)
if 'validation_split_by' in options:
options['validation_split_classes'] = np.array([
x[options['validation_split_by']] for x in metadata
])

# run classifiers and obtain results
results = {}
Expand Down
1 change: 1 addition & 0 deletions kameris/schemas/job_options.json
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
"type": "integer",
"minimum": 1
},
"validation_split_by": {"type": "string"},
"dim_reduce_fraction": {
"type": "number",
"minimum": 0,
Expand Down

0 comments on commit 7327ac5

Please sign in to comment.