-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathgenerate_samples.py
157 lines (130 loc) · 4.62 KB
/
generate_samples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
""" Demo script for generating random samples from a trained MSG-StyleGAN (or StyleGAN) model, OR
for showing the effect of stochastic noise on a fixed latent point """
import argparse
import os
import pickle
from math import sqrt
from pathlib import Path
import dnnlib.tflib as tflib
import imageio
import numpy as np
from tqdm import tqdm
def _str2bool(value):
    """
    Convert a command-line string into a boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False" and "0") as True, so boolean options need an explicit converter.

    :param value: string given on the command line (or an actual bool)
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """
    Parse the command-line arguments of the sample generator.

    :param argv: list of argument strings to parse; when None, argparse reads
                 ``sys.argv[1:]`` (the original behaviour)
    :return: argparse.Namespace holding the parsed arguments
    """
    parser = argparse.ArgumentParser("MSG-StyleGAN image_generator")
    parser.add_argument(
        "--pickle_file",
        type=str,
        required=True,
        action="store",
        help="pickle file containing the trained styleGAN model",
    )
    parser.add_argument(
        "--output_path",
        action="store",
        type=str,
        required=True,
        help="Path to directory for saving the files",
    )
    parser.add_argument(
        "--random_state",
        action="store",
        type=int,
        default=33,
        help="random_state (seed) for the script to run",
    )
    parser.add_argument(
        "--out_depth",
        action="store",
        type=int,
        default=None,
        help="output depth of the generated images",
    )
    parser.add_argument(
        "--truncation_psi",
        action="store",
        type=float,
        default=0.6,
        help="value of truncation psi used for generating the images",
    )
    parser.add_argument(
        "--num_samples",
        action="store",
        type=int,
        default=100,
        help="Number of samples to be generated",
    )
    # Boolean options: accept an explicit value ("True"/"False", "1"/"0", ...)
    # or the bare flag, which means True.
    parser.add_argument(
        "--only_noise",
        action="store",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="to visualize the same point with only different realizations of noise",
    )
    parser.add_argument(
        "--run_stylegan",
        action="store",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether you are running an MSG-StyleGAN model or just styleGAN",
    )
    return parser.parse_args(argv)
def get_image(gen, point, out_depth, truncation_psi=0.6, run_stylegan=False):
    """
    Generate a single image from the given latent point.

    For an MSG-StyleGAN generator, ``gen.run`` returns one output per depth and
    the one at index ``out_depth`` is returned. When ``out_depth`` is None or
    out of range (positive or negative), the last output is used instead. For
    a plain StyleGAN generator the single output is returned.

    :param gen: the generator object (must expose ``run``)
    :param point: latent vector for generation (a batch axis of size 1 is added here)
    :param out_depth: index of the MSG-StyleGAN output to return; ignored when run_stylegan
    :param truncation_psi: value of truncation psi used for generating the images
    :param run_stylegan: True when ``gen`` is a plain StyleGAN (not MSG) generator
    :return: the generated uint8 image with singleton dimensions squeezed out
    """
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # the generator works on batches: wrap the single point into a batch of one
    batch = np.expand_dims(point, axis=0)
    # randomize_noise=True: noise is re-sampled on every call, which is what
    # lets the same point produce different images in --only_noise mode
    images = gen.run(
        batch, None, truncation_psi=truncation_psi, randomize_noise=True, output_transform=fmt
    )
    if run_stylegan:
        return np.squeeze(images)
    num_outputs = len(images)
    # out-of-range depths (including too-negative ones) fall back to the last output
    if out_depth is None or not -num_outputs <= out_depth < num_outputs:
        out_depth = -1
    return np.squeeze(images[out_depth])
def _make_latents(rnd, num_samples, latent_size, only_noise=False):
    """
    Draw latent vectors rescaled to have L2 norm sqrt(latent_size).

    :param rnd: numpy RandomState used for sampling
    :param num_samples: number of latent vectors to return
    :param latent_size: dimensionality of each latent vector
    :param only_noise: if True, every row is a copy of the first sampled vector
    :return: array of shape (num_samples, latent_size)
    """
    latents = rnd.randn(num_samples, latent_size)
    latents = latents / np.linalg.norm(latents, axis=-1, keepdims=True) * sqrt(latent_size)
    if only_noise:
        # Reuse one latent point for every sample, so images differ only by noise.
        # Slicing [:1] keeps the 2-D shape (also for latent_size == 1) and
        # yields an empty array instead of an IndexError when num_samples == 0.
        latents = np.repeat(latents[:1], num_samples, axis=0)
    return latents


def main(args):
    """
    Load a trained generator and write generated samples to disk.

    Images are written as ``1.png``, ``2.png``, ... into ``args.output_path``.

    :param args: namespace as produced by ``parse_arguments``
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    # NOTE: unpickling can execute arbitrary code -- only load pickle files you trust.
    with open(args.pickle_file, "rb") as f:
        _, _, Gs = pickle.load(f)
    # The pickle holds (G, D, Gs). G and D are instantaneous snapshots of the
    # generator / discriminator, mainly useful for resuming training. Gs is the
    # long-term average of the generator, which yields higher-quality results.

    # Print network details.
    print("\n\nLoaded the Generator as:")
    Gs.print_layers()

    # Sample the latent points (optionally all identical, see --only_noise).
    latent_size = Gs.input_shape[1]
    rnd = np.random.RandomState(args.random_state)
    points = _make_latents(rnd, args.num_samples, latent_size, args.only_noise)

    # make sure that the output path exists (including any missing parents)
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    # getattr keeps callers that build their own Namespace without this field working
    run_stylegan = getattr(args, "run_stylegan", False)

    print("Generating the requested number of samples ... ")
    for count, point in tqdm(enumerate(points, 1), total=len(points)):
        image = get_image(
            Gs, point, args.out_depth, args.truncation_psi, run_stylegan=run_stylegan
        )
        imageio.imwrite(os.path.join(output_path, str(count) + ".png"), image)
    print(f"Requested images have been generated at: {output_path}")
if __name__ == "__main__":
    # script entry point: parse the CLI, then run the generator
    cli_args = parse_arguments()
    main(cli_args)