-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstyleGAN.py
164 lines (128 loc) · 5.19 KB
/
styleGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# -*- coding: utf-8 -*-
"""Untitled2.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1LmmjkZG437l3T2y_3NdcNOICqya7MOfU
"""
#!git clone https://github.com/NVlabs/stylegan.git
# cd stylegan
import config
import dnnlib.tflib as tflib
import dnnlib
import PIL.Image
import numpy as np
import pickle
import random
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = ""
from PIL import Image
import matplotlib.pyplot as plt
from pylab import rcParams

# Pretrained StyleGAN network pickle used by this script (defined once so the
# URL cannot drift between its uses).
# NOTE(review): this Google Drive ID appears to be NVlabs'
# karras2019stylegan-cats-256x256.pkl (``url_cats`` in the StyleGAN repo's
# generate_figures.py), not the FFHQ model -- consistent with the 256x256
# cells and cat attributes used further down. Confirm before swapping models.
url = "https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ"

# Sanity check: load the network, generate one image from seed 200 and save it
# as <config.result_dir>/example200.png.
tflib.init_tf()
with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
    # NOTE(review): pickle.load executes code embedded in the downloaded file;
    # only point this at trusted URLs. Only Gs is used below.
    _G, _D, Gs = pickle.load(f)
rnd = np.random.RandomState(200)
latents = rnd.randn(1, Gs.input_shape[1])
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(
    latents, None, truncation_psi=1, randomize_noise=True, output_transform=fmt
)
os.makedirs(config.result_dir, exist_ok=True)
png_filename = os.path.join(config.result_dir, "example200.png")
PIL.Image.fromarray(images[0], "RGB").save(png_filename)

rcParams["figure.figsize"] = 20, 20

# Kept under its historical name because generate_cat_fugure() reads it. It is
# the same pickle as ``url`` above (see the note there -- it is not FFHQ).
url_ffhq = url

# url -> Gs network cache filled lazily by load_Gs().
_Gs_cache = dict()

# Keyword arguments shared by every Gs.components.synthesis.run() call:
# uint8 conversion with NCHW -> NHWC transpose, processed in minibatches of 8.
synthesis_kwargs = dict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    minibatch_size=8,
)
def draw_style_mixing_figure(png, Gs, w, h, src_seeds, dst_seeds, style_ranges):
    """Render a StyleGAN style-mixing grid and save it to *png*.

    Layout (each cell is ``w`` x ``h`` pixels):

    * The top row holds the images generated from *src_seeds*.
    * The left column holds the images generated from *dst_seeds*.
    * Cell ``(row, col)`` is the dst image of that row with the dlatent layers
      selected by ``style_ranges[row]`` replaced by those of the src image of
      that column.
    * The top-left cell stays white.

    Args:
        png: Path the finished grid is saved to (the path is also printed).
        Gs: Generator network; must provide ``input_shape`` plus
            ``components.mapping.run`` and ``components.synthesis.run``.
        w: Cell width in pixels.
        h: Cell height in pixels.
        src_seeds: Seeds of the style donors (one grid column each).
        dst_seeds: Seeds of the base images (one grid row each).
        style_ranges: For each dst row, the dlatent layer indices copied from
            the src dlatents; needs at least ``len(dst_seeds)`` entries.
    """
    print(png)
    latent_size = Gs.input_shape[1]

    # np.stack needs a real sequence. Passing a generator was deprecated in
    # NumPy 1.16 and is rejected with TypeError by current releases.
    src_latents = np.stack(
        [np.random.RandomState(seed).randn(latent_size) for seed in src_seeds]
    )
    dst_latents = np.stack(
        [np.random.RandomState(seed).randn(latent_size) for seed in dst_seeds]
    )

    src_dlatents = Gs.components.mapping.run(
        src_latents, None
    )  # [seed, layer, component]
    dst_dlatents = Gs.components.mapping.run(
        dst_latents, None
    )  # [seed, layer, component]
    src_images = Gs.components.synthesis.run(
        src_dlatents, randomize_noise=False, **synthesis_kwargs
    )
    dst_images = Gs.components.synthesis.run(
        dst_dlatents, randomize_noise=False, **synthesis_kwargs
    )

    canvas = PIL.Image.new(
        "RGB", (w * (len(src_seeds) + 1), h * (len(dst_seeds) + 1)), "white"
    )
    # Header row: the unmodified source images.
    for col, src_image in enumerate(src_images):
        canvas.paste(PIL.Image.fromarray(src_image, "RGB"), ((col + 1) * w, 0))
    for row, dst_image in enumerate(dst_images):
        # Header column: the unmodified destination image of this row.
        canvas.paste(PIL.Image.fromarray(dst_image, "RGB"), (0, (row + 1) * h))
        # One copy of this row's dlatents per source, with the selected layers
        # overwritten by the corresponding source's dlatents.
        row_dlatents = np.stack([dst_dlatents[row]] * len(src_seeds))
        row_dlatents[:, style_ranges[row]] = src_dlatents[:, style_ranges[row]]
        row_images = Gs.components.synthesis.run(
            row_dlatents, randomize_noise=False, **synthesis_kwargs
        )
        for col, image in enumerate(row_images):
            canvas.paste(
                PIL.Image.fromarray(image, "RGB"), ((col + 1) * w, (row + 1) * h)
            )
    canvas.save(png)
def load_Gs(url):
    """Return the ``Gs`` network from the pickle at *url*, memoised per URL.

    The first call for a given *url* opens it through ``dnnlib.util.open_url``
    (using ``config.cache_dir``), unpickles the three-element tuple and keeps
    only its last element in the module-level ``_Gs_cache``. Later calls return
    that same object.
    """
    if url in _Gs_cache:
        return _Gs_cache[url]
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as stream:
        # NOTE(review): pickle.load executes code embedded in the file; only
        # use trusted network pickles here.
        _, _, generator = pickle.load(stream)
        _Gs_cache[url] = generator
    return _Gs_cache[url]
# Pick two seeds from the CSV. The seed matching the requested hair length is
# passed as src_seeds; the seed matching the requested hair colour/pattern is
# passed as dst_seeds. The hair-length seed is used exactly as read.
def generate_cat_fugure(color, variety, hair_length, csv_path="./catcsv.csv"):
    """Look up seeds for the requested cat attributes and draw a style mix.

    Reads the seed table at *csv_path*, which needs the columns
    ``"color of hair"``, ``"pattern of hair"``, ``"length of hair"`` and
    ``"seed"``. It picks one seed for the colour/pattern and one for the hair
    length, then renders them with ``draw_style_mixing_figure`` into
    ``<config.result_dir>/figure03-style-mixing15.png``.

    Args:
        color: Value to match in ``"color of hair"`` (e.g. "white").
        variety: Hair pattern to match in ``"pattern of hair"`` (e.g. "Tora").
            Falls back to the ``"mix"`` pattern when no row of *color* has it.
        hair_length: Value to match in ``"length of hair"``.
        csv_path: Path of the seed table; defaults to the original
            ``./catcsv.csv``.

    Raises:
        LookupError: If no seed exists for *color* with either *variety* or
            "mix", or none exists for *hair_length*.
    """
    import pandas as pd

    tflib.init_tf()
    data = pd.read_csv(csv_path)

    # Seed for the initial hair pattern and colour of the generated image.
    pattern_seed = _seed_for_pattern(data, color, variety)
    # Seed for the initial hair length and pose of the generated image.
    length_seed = _seed_for_length(data, hair_length)

    draw_style_mixing_figure(
        os.path.join(config.result_dir, "figure03-style-mixing15.png"),
        load_Gs(url_ffhq),
        w=256,
        h=256,
        src_seeds=[length_seed],
        dst_seeds=[pattern_seed],
        # With a single dst seed only style_ranges[0] (layers 0-3) is used:
        # those layers are copied from the hair-length image.
        style_ranges=[range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, 14)] * 1,
    )


def _seed_for_pattern(data, color, variety):
    """Return the first seed of colour *color* with pattern *variety*, else "mix"."""
    same_color = data[data["color of hair"] == color]
    # Prefer the user's pattern; "mix" is the fallback default.
    for pattern in (variety, "mix"):
        matches = same_color[same_color["pattern of hair"] == pattern]
        if not matches.empty:
            return int(matches["seed"].values[0])
    raise LookupError(
        "no seed for color %r with pattern %r or 'mix'" % (color, variety)
    )


def _seed_for_length(data, hair_length):
    """Return the first seed whose hair length equals *hair_length*."""
    matches = data[data["length of hair"] == hair_length]
    if matches.empty:
        raise LookupError("no seed for hair length %r" % (hair_length,))
    return int(matches["seed"].values[0])