Minor Python 3 fixups. #1

Open · wants to merge 1 commit into base: master
4 changes: 2 additions & 2 deletions dataset.py
@@ -9,7 +9,7 @@
 import numpy as np
 import h5py
 import threading
-import Queue
+from six.moves import queue
 
 #----------------------------------------------------------------------------

@@ -60,7 +60,7 @@ def __init__(self,
             self.reshuffle_window = min(self.order.size / 2, self.order.size - self.prefetch_images * 2 - 1)
         else:
             self.reshuffle_window = 1
-        self.queue = Queue.Queue(self.prefetch_images)
+        self.queue = queue.Queue(self.prefetch_images)
         self.thread = None
         self.cur_pos = 0
         self.cur_lod = -1
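Review note: six.moves.queue resolves to the stdlib Queue module on Python 2 and queue on Python 3, so the prefetch queue keeps identical semantics on both interpreters. A minimal sketch of the pattern (standalone, not part of the diff):

    # Compatibility alias: Queue (Py2) / queue (Py3) via six.moves.
    from six.moves import queue

    q = queue.Queue(2)            # bounded FIFO, same API on both interpreters
    q.put('minibatch-0')
    q.put('minibatch-1')
    assert q.get() == 'minibatch-0'

One untouched context line may still deserve attention: self.order.size / 2 yields a float under Python 3, so reshuffle_window can become a float; floor division (self.order.size // 2) would preserve the Python 2 integer behavior.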
15 changes: 8 additions & 7 deletions misc.py
@@ -4,14 +4,15 @@
 # 4.0 International License. To view a copy of this license, visit
 # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
 # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+from __future__ import print_function
 
 import os
 import sys
 import glob
 import math
 import types
 import datetime
-import cPickle
+from six.moves import cPickle
 import numpy as np
 from collections import OrderedDict
 import PIL.Image
@@ -175,7 +176,7 @@ def create_result_subdir(result_dir, run_desc):
                 continue
             raise
 
-    print "Saving results to", result_subdir
+    print("Saving results to", result_subdir)
     set_output_log_file(os.path.join(result_subdir, 'log.txt'))
 
     # Export run details.
@@ -209,8 +210,8 @@ def print_network_topology_info(layers):
     import lasagne
 
     print
-    print "%-16s%-28s%-10s%-20s%-20s%s" % ('LayerName', 'LayerType', 'Params', 'OutputShape', 'WeightShape', 'Activation')
-    print "%-16s%-28s%-10s%-20s%-20s%s" % (('---',) * 6)
+    print("%-16s%-28s%-10s%-20s%-20s%s" % ('LayerName', 'LayerType', 'Params', 'OutputShape', 'WeightShape', 'Activation'))
+    print("%-16s%-28s%-10s%-20s%-20s%s" % (('---',) * 6))
     total_params = 0
 
     for l in lasagne.layers.get_all_layers(layers):
@@ -227,10 +228,10 @@ def print_network_topology_info(layers):
             weights = np.zeros(())
         weight_str = shape_to_str(weights.shape)
         act_str = '' if not hasattr(l, 'nonlinearity') else l.nonlinearity.__name__ if isinstance(l.nonlinearity, types.FunctionType) else type(l.nonlinearity).__name__
-        print "%-16s%-28s%-10d%-20s%-20s%s" % (l.name, type_str, nparams, shape_to_str(outshape), weight_str, act_str)
+        print("%-16s%-28s%-10d%-20s%-20s%s" % (l.name, type_str, nparams, shape_to_str(outshape), weight_str, act_str))
 
-    print "%-16s%-28s%-10s%-20s%-20s%s" % (('---',) * 6)
-    print "%-16s%-28s%-10d%-20s%-20s%s" % ('Total', '', total_params, '', '', '')
+    print("%-16s%-28s%-10s%-20s%-20s%s" % (('---',) * 6))
+    print("%-16s%-28s%-10d%-20s%-20s%s" % ('Total', '', total_params, '', '', ''))
     print
 
 def shape_to_str(shape):
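Review note: once from __future__ import print_function is in effect, the two bare print statements left in the context lines (the blank-line separators around the topology table) stop printing: a bare print is then just an expression referencing the built-in function and produces no output. If the blank lines matter, they would need to become print() calls too. A small demonstration of the difference, under that future import:

    from __future__ import print_function

    print      # no-op: evaluates the function object, prints nothing
    print()    # prints an empty line, matching the old Python 2 bare 'print'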
4 changes: 2 additions & 2 deletions network.py
@@ -14,7 +14,7 @@
 import theano
 from theano import tensor as T
 import lasagne
-import cPickle
+from six.moves import cPickle
 
 # NOTE: Do not reference config.py here!
 # Instead, specify all network parameters as build function arguments.
@@ -184,7 +184,7 @@ def _call_build_func(self, module_globals):
 
     def _call_build_func_from_src(self):
         tmp_module = imp.new_module('network_tmp_module')
-        exec self.build_module_src in tmp_module.__dict__
+        exec(self.build_module_src, tmp_module.__dict__)
         globals()['tmp_modules'] = globals().get('tmp_modules', []) + [tmp_module] # Work around issues with GC.
         return self._call_build_func(tmp_module.__dict__)

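Review note: the Python 3 translation of the statement form "exec src in ns" is the two-argument call exec(src, ns), which is what the hunk above uses. Merely parenthesizing the old statement as exec(src in ns) would instead evaluate the membership test "src in ns" and try to exec the resulting bool, raising a TypeError. A minimal sketch of the pattern (the module source here is hypothetical, not from the repo):

    import imp  # deprecated on Python 3, but it is what network.py already uses

    tmp_module = imp.new_module('network_tmp_module')
    src = "def build(x):\n    return x * 2\n"  # hypothetical build-module source
    exec(src, tmp_module.__dict__)             # run src inside the module's namespace
    assert tmp_module.build(21) == 42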
102 changes: 52 additions & 50 deletions train.py
@@ -4,24 +4,26 @@
 # 4.0 International License. To view a copy of this license, visit
 # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
 # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
+from __future__ import print_function
 
 import os
 import sys
 import time
 import glob
 import shutil
 import operator
+import six
 import numpy as np
 import scipy.ndimage
 
 import misc
 misc.init_output_logging()
 
 if __name__ == "__main__":
-    print 'Importing Theano...'
+    print('Importing Theano...')
 
 import config
-os.environ['THEANO_FLAGS'] = ','.join([key + '=' + value for key, value in config.theano_flags.iteritems()])
+os.environ['THEANO_FLAGS'] = ','.join([key + '=' + value for key, value in six.iteritems(config.theano_flags)])
 sys.setrecursionlimit(10000)
 import theano
 from theano import tensor as T
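Review note: six.iteritems(d) calls d.iteritems() on Python 2 and d.items() on Python 3, so the THEANO_FLAGS join works unchanged on both. Since this dict is tiny and iterated once, a plain d.items() would also work on both interpreters (it merely builds a small list on Python 2). A sketch of the pattern with a stand-in dict (the theano_flags contents below are example values only):

    import os
    import six

    theano_flags = {'floatX': 'float32', 'device': 'gpu0'}  # example values only
    os.environ['THEANO_FLAGS'] = ','.join(
        key + '=' + value for key, value in six.iteritems(theano_flags))
    assert 'floatX=float32' in os.environ['THEANO_FLAGS']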
@@ -55,24 +57,24 @@ def random_labels(num_labels, training_set):
     return training_set.labels[np.random.randint(training_set.labels.shape[0], size=num_labels)]
 
 def load_dataset(dataset_spec=None, verbose=False, **spec_overrides):
-    if verbose: print 'Loading dataset...'
+    if verbose: print('Loading dataset...')
     if dataset_spec is None: dataset_spec = config.dataset
     dataset_spec = dict(dataset_spec) # take a copy of the dict before modifying it
     dataset_spec.update(spec_overrides)
     dataset_spec['h5_path'] = os.path.join(config.data_dir, dataset_spec['h5_path'])
     if 'label_path' in dataset_spec: dataset_spec['label_path'] = os.path.join(config.data_dir, dataset_spec['label_path'])
     training_set = dataset.Dataset(**dataset_spec)
-    if verbose: print 'Dataset shape =', np.int32(training_set.shape).tolist()
+    if verbose: print('Dataset shape =', np.int32(training_set.shape).tolist())
     drange_orig = training_set.get_dynamic_range()
-    if verbose: print 'Dynamic range =', drange_orig
+    if verbose: print('Dynamic range =', drange_orig)
     return training_set, drange_orig
 
 def load_dataset_for_previous_run(result_subdir, **kwargs):
     dataset = None
     with open(os.path.join(result_subdir, 'config.txt'), 'rt') as f:
         for line in f:
             if line.startswith('dataset = '):
-                exec line
+                exec(line)
     return load_dataset(dataset, **kwargs)
 
 #----------------------------------------------------------------------------
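Review note: exec(line) inside a function is a subtle Python 3 trap. Assignments made by the executed code land in a throwaway locals mapping, so the enclosing function's dataset variable stays None even after a "dataset = ..." line runs; Python 2's exec statement could rebind function locals, Python 3's exec() cannot. A safer translation executes into an explicit namespace and reads the result back, along these lines (a sketch under that assumption, not what the diff currently does):

    def load_dataset_spec(config_path):
        # Execute the matching line into an explicit dict, then fetch the result.
        namespace = {}
        with open(config_path, 'rt') as f:
            for line in f:
                if line.startswith('dataset = '):
                    exec(line, namespace)
        return namespace.get('dataset')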
@@ -114,7 +116,7 @@ def train_gan(
     # Load dataset and build networks.
     training_set, drange_orig = load_dataset()
     if resume_network_pkl:
-        print 'Resuming', resume_network_pkl
+        print('Resuming', resume_network_pkl)
         G, D, _ = misc.load_pkl(os.path.join(config.result_dir, resume_network_pkl))
     else:
         G = network.Network(num_channels=training_set.shape[1], resolution=training_set.shape[2], label_size=training_set.labels.shape[1], **config.G)
@@ -146,7 +148,7 @@ def train_gan(
         raise ValueError('Invalid image_grid_type', image_grid_type)
 
     # Theano input variables and compile generation func.
-    print 'Setting up Theano...'
+    print('Setting up Theano...')
     real_images_var  = T.TensorType('float32', [False] * len(D.input_shape))            ('real_images_var')
     real_labels_var  = T.TensorType('float32', [False] * len(training_set.labels.shape))('real_labels_var')
     fake_latents_var = T.TensorType('float32', [False] * len(G.input_shape))            ('fake_latents_var')
@@ -163,7 +165,7 @@
     fake_score_avg = 0.0
 
     if config.D.get('mbdisc_kernels', None):
-        print 'Initializing minibatch discrimination...'
+        print('Initializing minibatch discrimination...')
         if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(initial_lod))
         D.eval(real_images_var, deterministic=False, init=True)
         init_layers = lasagne.layers.get_all_layers(D.output_layers)
@@ -214,7 +216,7 @@
         # Setup training func for current LOD.
         new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))
         if min_lod != new_min_lod or max_lod != new_max_lod:
-            print 'Compiling training funcs...'
+            print('Compiling training funcs...')
             min_lod, max_lod = new_min_lod, new_max_lod
 
             # Pre-process reals.
@@ -283,8 +285,8 @@
             tick_train_out = []
 
             # Print progress.
-            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
-                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size, misc.format_time(cur_time - train_start_time), tick_time, tick_time / tick_kimg, gdrop_strength) + tick_train_avg)
+            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
+                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size, misc.format_time(cur_time - train_start_time), tick_time, tick_time / tick_kimg, gdrop_strength) + tick_train_avg))
 
             # Visualize generated images.
             if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
@@ -298,7 +300,7 @@
     # Write final results.
     misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
     training_set.close()
-    print 'Done.'
+    print('Done.')
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
         pass

@@ -430,13 +432,13 @@ def interpolate_latents(
     if image_grid_size is None: image_grid_size = np.clip(int(np.floor(1920 / (w * zoom))), 1, 16), np.clip(int(np.floor(1080 / (h * zoom))), 1, 16)
 
     # Generate latent vectors (frame, image, channel, component).
-    print 'Generating latent vectors...'
+    print('Generating latent vectors...')
     latents = np.random.randn(num_frames, np.prod(image_grid_size), *net.G.input_shape[1:]).astype(np.float32)
     latents = scipy.ndimage.gaussian_filter(latents, [filter_frames] + [0] * len(net.G.input_shape), mode='wrap')
     latents /= np.sqrt(np.mean(latents ** 2))
 
     # Create video.
-    print 'Generating video...'
+    print('Generating video...')
     result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)
     def make_frame(t):
         frame_idx = np.clip(int(np.round(t * video_fps)), 0, num_frames - 1)
@@ -450,7 +452,7 @@ def make_frame(t):
     video.write_videofile(os.path.join(result_subdir, os.path.basename(result_subdir) + '.mp4'), fps=video_fps, codec='libx264', bitrate=video_bitrate)
 
     # Done.
-    print 'Done.'
+    print('Done.')
     with open(os.path.join(result_subdir, '_video-done.txt'), 'wt'):
         pass

@@ -461,7 +463,7 @@ def calc_inception_scores(run_id, log='inception.txt', num_images=50000, minibat
     network_pkls = misc.list_network_pkls(result_subdir)
     misc.set_output_log_file(os.path.join(result_subdir, log))
 
-    print 'Importing inception score module...'
+    print('Importing inception score module...')
     import inception_score
     def calc_inception_score(images):
         if images.shape[1] == 1:
@@ -475,24 +477,24 @@ def calc_inception_score(images):
 
     # Evaluate reals.
     if eval_reals:
-        print 'Evaluating inception score for reals...'
+        print('Evaluating inception score for reals...')
         time_begin = time.time()
         mean, std = calc_inception_score(reals)
-        print 'Done in %s' % misc.format_time(time.time() - time_begin)
-        print '%-32s mean %-8.4f std %-8.4f' % ('reals', mean, std)
+        print('Done in %s' % misc.format_time(time.time() - time_begin))
+        print('%-32s mean %-8.4f std %-8.4f' % ('reals', mean, std))
 
     # Evaluate each network snapshot.
     network_pkls = list(enumerate(network_pkls))
     if reverse_order:
         network_pkls = network_pkls[::-1]
     for network_idx, network_pkl in network_pkls:
-        print '%-32s' % os.path.basename(network_pkl),
+        print('%-32s' % os.path.basename(network_pkl), end=' ')
         net = imgapi_load_net(run_id=result_subdir, snapshot=network_pkl, num_example_latents=num_images, random_seed=network_idx)
         fakes = imgapi_generate_batch(net, net.example_latents, np.random.permutation(labels), minibatch_size=minibatch_size, convert_to_uint8=True)
         mean, std = calc_inception_score(fakes)
-        print 'mean %-8.4f std %-8.4f' % (mean, std)
-    print
-    print 'Done.'
+        print('mean %-8.4f std %-8.4f' % (mean, std))
+    print()
+    print('Done.')
 
 #----------------------------------------------------------------------------
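Review note: the standard translation of the Python 2 trailing-comma form (print x,) is print(x, end=' '), which this diff uses throughout. The two are not byte-for-byte identical: Python 2 emits the space lazily via the softspace mechanism, while end=' ' always appends it, but for these progress tables the visible output is the same. A compact comparison with stand-in values (the snapshot name and numbers are examples only):

    from __future__ import print_function

    print('%-32s' % 'network-snapshot-000042.pkl', end=' ')  # stay on the same line
    print('mean %-8.4f std %-8.4f' % (8.8070, 0.1354))       # example values only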

@@ -514,7 +516,7 @@ def calc_sliced_wasserstein_scores(
     misc.set_output_log_file(os.path.join(result_subdir, log))
 
     # Load dataset.
-    print 'Loading dataset...'
+    print('Loading dataset...')
     training_set, drange_orig = load_dataset_for_previous_run(result_subdir)
     assert training_set.shape[1] == 3 # RGB
     assert num_images % minibatch_size == 0
@@ -527,7 +529,7 @@
     resolutions = [2**i for i in xrange(int(np.log2(resolution_max)), int(np.log2(resolution_min)) - 1, -1)]
 
     # Collect descriptors for reals.
-    print 'Extracting descriptors for reals...',
+    print('Extracting descriptors for reals...', end=' ')
     time_begin = time.time()
     desc_real = [[] for res in resolutions]
     desc_test = [[] for res in resolutions]
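Review note: the untouched context line that builds resolutions still calls xrange, which does not exist on Python 3 and will raise a NameError the first time this function runs. Since the PR already depends on six, "from six.moves import range" (or plain range, at the cost of materializing a small list on Python 2) would cover it. A sketch of the fixed expression with example bounds:

    from six.moves import range  # xrange on Py2, built-in range on Py3

    import numpy as np

    resolution_max, resolution_min = 1024, 16  # example values only
    resolutions = [2**i for i in range(int(np.log2(resolution_max)),
                                       int(np.log2(resolution_min)) - 1, -1)]
    assert resolutions == [1024, 512, 256, 128, 64, 32, 16]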
@@ -536,37 +538,37 @@
         for lod, level in enumerate(sliced_wasserstein.generate_laplacian_pyramid(minibatch, len(resolutions))):
             desc_real[lod].append(sliced_wasserstein.get_descriptors_for_minibatch(level, nhood_size, nhoods_per_image))
             desc_test[lod].append(sliced_wasserstein.get_descriptors_for_minibatch(level, nhood_size, nhoods_per_image))
-    print 'done in %s' % misc.format_time(time.time() - time_begin)
+    print('done in %s' % misc.format_time(time.time() - time_begin))
 
     # Evaluate scores for reals.
-    print 'Evaluating scores for reals...',
+    print('Evaluating scores for reals...', end=' ')
     time_begin = time.time()
     scores = []
     for lod, res in enumerate(resolutions):
         desc_real[lod] = sliced_wasserstein.finalize_descriptors(desc_real[lod])
         desc_test[lod] = sliced_wasserstein.finalize_descriptors(desc_test[lod])
         scores.append(sliced_wasserstein.sliced_wasserstein(desc_real[lod], desc_test[lod], dir_repeats, dirs_per_repeat))
     del desc_test
-    print 'done in %s' % misc.format_time(time.time() - time_begin)
+    print('done in %s' % misc.format_time(time.time() - time_begin))
 
     # Print table header.
-    print
-    print '%-32s' % 'Case',
+    print()
+    print('%-32s' % 'Case', end=' ')
     for lod, res in enumerate(resolutions):
-        print '%-12s' % ('%dx%d' % (res, res)),
-    print 'Average'
-    print '%-32s' % '---',
+        print('%-12s' % ('%dx%d' % (res, res)), end=' ')
+    print('Average')
+    print('%-32s' % '---', end=' ')
     for lod, res in enumerate(resolutions):
-        print '%-12s' % '---',
-    print '---'
-    print '%-32s' % 'reals',
+        print('%-12s' % '---', end=' ')
+    print('---')
+    print('%-32s' % 'reals', end=' ')
     for lod, res in enumerate(resolutions):
-        print '%-12.6f' % scores[lod],
-    print '%.6f' % np.mean(scores)
+        print('%-12.6f' % scores[lod], end=' ')
+    print('%.6f' % np.mean(scores))
 
     # Process each network snapshot.
     for network_idx, network_pkl in enumerate(network_pkls):
-        print '%-32s' % os.path.basename(network_pkl),
+        print('%-32s' % os.path.basename(network_pkl), end=' ')
         net = imgapi_load_net(run_id=result_subdir, snapshot=network_pkl, num_example_latents=num_images, random_seed=network_idx)
 
         # Extract descriptors for generated images.
@@ -588,10 +590,10 @@
 
         # Report results.
         for lod, res in enumerate(resolutions):
-            print '%-12.6f' % scores[lod],
-        print '%.6f' % np.mean(scores)
-    print
-    print 'Done.'
+            print('%-12.6f' % scores[lod], end=' ')
+        print('%.6f' % np.mean(scores))
+    print()
+    print('Done.')
 
 #----------------------------------------------------------------------------

Expand All @@ -605,7 +607,7 @@ def calc_mnistrgb_histogram(run_id, num_images=25600, log='histogram.txt', minib
classify_fn = theano.function([input_var], [output_expr])

# Process folders
print 'Processing directory %s' % (run_id)
print('Processing directory %s' % (run_id))
result_subdir = misc.locate_result_subdir(run_id)

network_pkls = misc.list_network_pkls(result_subdir)
@@ -642,20 +644,20 @@ def calc_histogram(images_all):
 
     # Evaluate reals.
     if eval_reals:
-        print 'Evaluating histogram for reals...'
+        print('Evaluating histogram for reals...')
         time_begin = time.time()
         mean, kld = calc_histogram(reals)
-        print 'Done in %s' % misc.format_time(time.time() - time_begin)
-        print 'mean %-8.4f kld %-8.4f' % (mean, kld)
+        print('Done in %s' % misc.format_time(time.time() - time_begin))
+        print('mean %-8.4f kld %-8.4f' % (mean, kld))
 
     # Evaluate each network snapshot.
     latents = None
     for network_idx, network_pkl in enumerate(network_pkls):
-        print '%-32s' % os.path.basename(network_pkl),
+        print('%-32s' % os.path.basename(network_pkl), end=' ')
         net = imgapi_load_net(run_id=result_subdir, snapshot=network_pkl, num_example_latents=num_images*num_evals)
         fakes = imgapi_generate_batch(net, net.example_latents, labels, minibatch_size=minibatch_size, convert_to_uint8=True)
         mean, kld = calc_histogram(fakes)
-        print 'mean %-8.4f kld %-8.4f' % (mean, kld)
+        print('mean %-8.4f kld %-8.4f' % (mean, kld))
 
 #----------------------------------------------------------------------------
