-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprune.py
62 lines (52 loc) · 2.12 KB
/
prune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from tqdm import tqdm
import torch
import numpy as np
def get_sparsity(tensor):
    """Return the sparsity ratio of *tensor*: the fraction of entries that are exactly zero.

    Args:
        tensor: a numpy array (``.size`` attribute and elementwise ``== 0`` required).

    Returns:
        float in [0, 1]; 0.0 for a fully dense array, 1.0 for an all-zero one.
    """
    # count_nonzero over the boolean mask counts exactly the zero entries
    zero_count = int(np.count_nonzero(tensor == 0))
    return zero_count / tensor.size
def prune_loop(model, loss, pruner, dataloader, device, sparsity, schedule, scope, epochs, clip, noise,
               reinitialize=False, train_mode=False, shuffle=False, invert=False):
    r"""Iteratively score and mask ``model``'s parameters toward a final sparsity level.

    Each epoch re-scores the model, computes an intermediate sparsity from the
    chosen schedule, applies the pruner's mask, and dumps the current masks and
    parameters to ``.npz`` files for downstream (JAX) consumption.

    Args:
        model: network to prune; modified in place via the pruner's masks.
        loss: loss function forwarded to ``pruner.score``.
        pruner: object exposing ``score``, ``mask``, ``param``, ``stats``
            (and optionally ``invert`` / ``shuffle``).
        dataloader: data used by the scoring pass.
        device: device the scoring pass runs on.
        sparsity: target fraction of weights remaining (used by the schedule).
        schedule: ``'exponential'`` or ``'linear'`` interpolation to the target.
        scope: masking scope forwarded to ``pruner.mask``.
        epochs: number of prune iterations.
        clip, noise: forwarded to ``pruner.score``.
        reinitialize: if True, re-initialize weights after pruning.
        train_mode: if True, keep the model in train mode while scoring.
        shuffle: if True, shuffle the masks after pruning.
        invert: if True, invert scores before masking each epoch.

    Raises:
        ValueError: if ``schedule`` is unrecognized, or if the achieved
            parameter count deviates from the target by 5 or more.
    """
    # Validate up-front: an unknown schedule would otherwise surface later
    # as a NameError on the unbound intermediate sparsity.
    if schedule not in ('exponential', 'linear'):
        raise ValueError("Unknown schedule {!r}; expected 'exponential' or 'linear'".format(schedule))

    # Single call sets train (True) or eval (False) mode for the scoring pass.
    model.train(train_mode)

    # Prune model
    for epoch in tqdm(range(epochs)):
        pruner.score(model, loss, dataloader, device, clip, noise)
        if schedule == 'exponential':
            sparse = sparsity**((epoch + 1) / epochs)
        else:  # 'linear'
            sparse = 1.0 - (1.0 - sparsity)*((epoch + 1) / epochs)
        # Invert scores so the opposite set of weights survives
        if invert:
            pruner.invert()
        mask = pruner.mask(sparse, scope)
        param = pruner.param()
        # Move everything to host numpy for serialization.
        mask_np = {k: v.cpu().numpy() for k, v in mask}
        param_np = {k: v.cpu().detach().numpy() for k, v in param}
        # NOTE(review): hard-coded output paths, written (and overwritten) every
        # epoch, to two duplicate locations — presumably so both package layouts
        # can find them; consider parameterizing. Kept as-is for compatibility.
        np.savez("jax_privacy/pruned_torch_weights.npz", **mask_np)
        np.savez("jax_privacy/jax_privacy/pruned_torch_weights.npz", **mask_np)
        np.savez("jax_privacy/torch_params.npz", **param_np)
        np.savez("jax_privacy/jax_privacy/torch_params.npz", **param_np)
        print("---------pruned weights stored----------")

    # Reinitialize weights
    if reinitialize:
        model._initialize_weights()

    # Shuffle masks
    if shuffle:
        pruner.shuffle()

    # Confirm sparsity level; raise (not quit()) so callers can recover/log.
    remaining_params, total_params = pruner.stats()
    if np.abs(remaining_params - total_params*sparsity) >= 5:
        raise ValueError(
            "ERROR: {} prunable parameters remaining, expected {}".format(
                remaining_params, total_params*sparsity))