forked from PeizhuoLi/ganimator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdemo.py
204 lines (174 loc) · 8.96 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import torch
import os
from os.path import join as pjoin
from dataset.motion import MotionData, load_multiple_dataset
from models import create_model, create_conditional_model
from models.architecture import draw_example, get_pyramid_lengths, FullGenerator
from option import TestOptionParser, TrainOptionParser
from fix_contact import fix_contact_on_file
from models.utils import get_layered_mask
from interactive_utils import sliding_window
def load_all_from_path(save_path, device, use_class=False):
    """Restore a trained model and its data from a checkpoint directory.

    The training options are read from ``args.txt`` inside ``save_path``.
    Motion data, one generator per pyramid stage, ``z_star.pt`` and
    ``amps.pt`` are then loaded according to those options.

    Args:
        save_path: Directory containing ``args.txt``, the ``gen*.pt``
            generator checkpoints, ``z_star.pt`` and ``amps.pt``.
        device: Device identifier; stored on the loaded options and used as
            ``map_location`` for every ``torch.load`` call.
        use_class: When true, wrap the loaded parts in a ``FullGenerator``
            instead of returning them individually.

    Returns:
        ``FullGenerator(args, motion_data, gens, z_star, amps)`` when
        ``use_class`` is true; otherwise the tuple
        ``(args, multiple_data, gens, z_star, amps, lengths)``.
    """
    args = TrainOptionParser().load(pjoin(save_path, 'args.txt'))
    args.device = device
    args.save_path = save_path
    device = torch.device(args.device)

    # Options shared by the single- and multi-sequence loaders.
    data_kwargs = dict(padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                       contact=args.contact, keep_y_pos=args.keep_y_pos)
    if args.multiple_sequences:
        multiple_data = load_multiple_dataset(prefix=args.bvh_prefix,
                                              name_list=pjoin(args.bvh_prefix, args.bvh_name),
                                              no_scale=True, **data_kwargs)
        motion_data = multiple_data[0]
    else:
        motion_data = MotionData(pjoin(args.bvh_prefix, f'{args.bvh_name}.bvh'), **data_kwargs)
        multiple_data = [motion_data]

    # One list of pyramid lengths per sequence. The shortest pyramid depth is
    # measured BEFORE the optional stage limit is applied (10000 is the
    # starting upper bound), then every pyramid keeps only its last
    # ``min_len`` entries.
    pyramids = [get_pyramid_lengths(args, len(sequence)) for sequence in multiple_data]
    min_len = min([10000] + [len(pyramid) for pyramid in pyramids])
    if args.num_stages_limit != -1:
        pyramids = [pyramid[:args.num_stages_limit] for pyramid in pyramids]
    lengths = [pyramid[-min_len:] for pyramid in pyramids]

    gens = []
    for step in range(len(lengths[0])):
        if args.conditional_generator and step < args.num_conditional_generator:
            build = create_conditional_model
        else:
            build = create_model
        gen = build(args, motion_data, evaluation=True)
        # Checkpoints are named with a zero-padded index; fall back to the
        # unpadded name when the padded file does not exist.
        try:
            state = torch.load(pjoin(args.save_path, f'gen{step:03d}.pt'), map_location=device)
        except FileNotFoundError:
            state = torch.load(pjoin(args.save_path, f'gen{step}.pt'), map_location=device)
        gen.load_state_dict(state)
        gens.append(gen)

    z_star = torch.load(pjoin(args.save_path, 'z_star.pt'), map_location=device)
    amps = torch.load(pjoin(args.save_path, 'amps.pt'), map_location=device)

    if not use_class:
        # Callers index these per sequence, so add a leading axis when the
        # stored values lack one.
        if len(amps.shape) == 1:
            amps = amps.unsqueeze(0)
        if isinstance(z_star, torch.Tensor) and len(z_star.shape) == 3:
            z_star = z_star.unsqueeze(0)
        return args, multiple_data, gens, z_star, amps, lengths

    # The wrapped generator takes the first entry of each when several are stored.
    if isinstance(z_star, list):
        z_star = z_star[0]
    if len(amps.shape) != 1:
        amps = amps[0]
    return FullGenerator(args, motion_data, gens, z_star, amps)
def write_multires(imgs, prefix, writer, interpolator, full_lengths=None, requires_con_loss=True):
    """Write every level of ``imgs`` to ``prefix`` and optionally score consistency.

    Each entry of ``imgs`` is resampled by ``interpolator`` to a common length
    and handed to ``writer`` together with the path ``<prefix>/<step:02d>.bvh``.

    Args:
        imgs: Sequence of tensors, one per level. The last axis of the final
            entry provides the target length unless ``full_lengths`` is given.
        prefix: Output directory; created if it does not exist.
        writer: Callable ``writer(path, data)`` that stores one level.
        interpolator: Callable ``interpolator(img, length)`` returning the
            resampled tensor.
        full_lengths: Optional target length overriding the one taken from
            ``imgs[-1]``.
        requires_con_loss: Whether to compute and return the consistency loss.

    Returns:
        When ``requires_con_loss`` is true, the MSE between rows 1 and 0 of
        the concatenated per-level norms (needs at least two rows). Otherwise
        ``None``, returned explicitly.
    """
    os.makedirs(prefix, exist_ok=True)
    length = imgs[-1].shape[-1] if full_lengths is None else full_lengths

    level_norms = []
    for step, img in enumerate(imgs):
        resampled = interpolator(img, length)
        writer(pjoin(prefix, f'{step:02d}.bvh'), resampled)
        # Norm over channels [-6:-3]; presumably the velocity channels (the
        # original code calls this value ``velo``) - TODO confirm.
        level_norms.append(resampled[:, -6:-3].norm(dim=1))

    if not requires_con_loss:
        # Explicit so the no-loss path never touches an unbound loss variable.
        return None

    stacked = torch.cat(level_norms, dim=0)
    return torch.nn.MSELoss()(stacked[1], stacked[0])
def gen_noise(n_channel, length, full_noise, device):
    """Draw a standard-normal noise tensor of shape ``(1, n_channel, length)``.

    Args:
        n_channel: Number of channels in the returned tensor.
        length: Size of the last axis.
        full_noise: If true, every channel gets independent noise; otherwise
            a single noise row is drawn and copied to all channels.
        device: Device the result is moved to after sampling.

    Returns:
        The noise tensor on ``device``.
    """
    if full_noise:
        return torch.randn((1, n_channel, length)).to(device)

    # One shared row, physically repeated so the result can be modified in place.
    shared_row = torch.randn((1, 1, length))
    return shared_row.repeat(1, n_channel, 1).to(device)
def _load_manip_data(path, args, no_scale):
    """Build the ``MotionData`` for a user-supplied clip with the training-time options."""
    return MotionData(f'{path}',
                      padding=args.skeleton_aware, use_velo=args.use_velo, repr=args.repr,
                      contact=args.contact, keep_y_pos=args.keep_y_pos,
                      no_scale=no_scale)


def main():
    """Reconstruct the training sequences, then generate one new motion.

    Steps:
      1. Parse the test options and load the trained model from
         ``test_args.save_path``.
      2. For every training sequence, write the ground truth and the
         reconstruction to ``<save_path>/bvh`` and print the reconstruction
         loss when the lengths match.
      3. Pick a generation mode from the test options (style transfer,
         keyframe editing, conditional generation, sampled condition, or
         unconditional) and generate ``result.bvh``.
      4. Run contact fixing on the result and write the manipulation input or
         the generated condition next to it.

    Raises:
        Exception: If conditional generation or interactive mode is requested
            for a non-conditional model, if a condition has to be sampled but
            no pretrained condition generator was loaded, or if interactive
            mode is requested without a full-resolution condition.
    """
    test_args = TestOptionParser().parse_args()

    args, multiple_data, gens, z_stars, amps, lengths = load_all_from_path(test_args.save_path, test_args.device)
    device = torch.device(args.device)
    n_total_levels = len(gens)

    motion_data = multiple_data[0]
    noise_channel = z_stars[0].shape[1] if args.full_noise else 1

    if len(args.path_to_existing):
        ConGen = load_all_from_path(args.path_to_existing, args.device, use_class=True)
        ConGen.output_mask = get_layered_mask(args.conditional_mode, motion_data.n_rot)
        conds_rec = [motion_data.sample(lengths[0][i]) for i in range(args.num_conditional_generator)]
    else:
        ConGen = None
        conds_rec = None

    print('levels:', lengths)
    save_path = pjoin(args.save_path, 'bvh')
    os.makedirs(save_path, exist_ok=True)

    base_id = 0

    # Evaluate with reconstruct noise
    # NOTE(review): ``motion_data`` is rebound by this loop, so every write
    # after it uses the LAST training sequence (unchanged from the original).
    for i, motion_data in enumerate(multiple_data):
        imgs = draw_example(gens, 'rec', z_stars[i], lengths[i] + [1], amps[i], 1, args, all_img=True,
                            conds=conds_rec, full_noise=args.full_noise)
        real = motion_data.sample(size=len(motion_data), slerp=args.slerp).to(device)
        motion_data.write(pjoin(save_path, f'gt_{i}.bvh'), real)
        motion_data.write(pjoin(save_path, f'rec_{i}.bvh'), imgs[-1])

        # The loss is only defined when both motions have the same length.
        if imgs[-1].shape[-1] == real.shape[-1]:
            rec_loss = torch.nn.MSELoss()(imgs[-1], real).item()
            print(f'rec_loss: {rec_loss:.07f}')

    generation_mode = 'manip' if test_args.style_transfer or test_args.keyframe_editing else 'random'

    # Only the conditional branches below provide a full-resolution condition.
    conds_full = None
    if test_args.style_transfer:
        manip_data = _load_manip_data(test_args.style_transfer, args, no_scale=False)
        target_length = get_pyramid_lengths(args, len(manip_data))
        conds = manip_data.sample(target_length[0])
    elif test_args.keyframe_editing:
        manip_data = _load_manip_data(test_args.keyframe_editing, args, no_scale=True)
        # Use original length of training data
        target_length = get_pyramid_lengths(args, len(multiple_data[0]))
        conds = manip_data.sample(target_length[0])
    elif test_args.conditional_generation:
        # Validate first so the intended message is not masked by a failure
        # while building the conditions.
        if not args.conditional_generator:
            raise Exception('Conditional generation only applicable to conditional generators.')
        manip_data = _load_manip_data(test_args.conditional_generation, args, no_scale=True)
        target_length = get_pyramid_lengths(args, len(manip_data))
        conds = [manip_data.sample(l) for l in target_length[:args.num_conditional_generator]]
        conds_full = conds[-1]
        generation_mode = 'cond'
    elif args.conditional_generator:
        # This is a conditional model, but the condition is not given.
        # The condition is therefore sampled from the ConGen model.
        if ConGen is None:
            raise Exception('No condition was given and no pretrained condition generator is available '
                            '(path_to_existing is empty).')
        target_length = get_pyramid_lengths(args, test_args.target_length)
        manip_data = None
        conds = ConGen.random_generate(target_length)
        conds_full = conds[-1]
        conds = conds[:args.num_conditional_generator]
    else:
        target_length = get_pyramid_lengths(args, test_args.target_length)
        manip_data = None
        conds = None

    # Drop the coarsest levels until there is one length per trained generator.
    excess_levels = len(target_length) - n_total_levels
    if excess_levels > 0:
        target_length = target_length[excess_levels:]

    z_length = target_length[0]

    z_target = gen_noise(noise_channel, z_length, args.full_noise, device)
    z_target *= amps[base_id][0]
    # Keep only the first amplitude; all later entries are zeroed.
    amps2 = amps[base_id].clone()
    amps2[1:] = 0

    if not test_args.interactive:
        imgs = draw_example(gens, generation_mode, z_stars[base_id], target_length, amps2, 1, args, all_img=True,
                            conds=conds, full_noise=args.full_noise, given_noise=[z_target])
    else:
        if not args.conditional_generator:
            raise Exception('Interactive mode only applicable to conditional generators.')
        if conds_full is None:
            raise Exception('Interactive mode requires a full-resolution condition, which style transfer '
                            'and keyframe editing do not provide.')
        _, imgs = sliding_window(gens, z_stars[base_id], amps2, conds_full, args)

    motion_data.write(pjoin(save_path, 'result.bvh'), imgs[-1])
    fix_contact_on_file(save_path, name='result')

    if manip_data is not None:
        motion_data.write(pjoin(save_path, 'manipulate_input.bvh'), manip_data.sample())
    elif args.conditional_generator:
        motion_data.write(pjoin(save_path, 'generated_traj.bvh'), conds_full)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()