-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
84 lines (61 loc) · 2.2 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
import datetime
import torch.nn as nn
import numpy as np
import random
class config:
    """Static hyper-parameters for data loading, training, the channel model
    and the network constructors.

    Every setting is a plain class attribute (e.g. ``config.lr``); nothing in
    this class computes anything at run time.
    """

    # ---- data / optimisation ------------------------------------------------
    train_data_dir = ['Path to train dataset']  # placeholder: replace with real path(s)
    test_data_dir = ['Path to test dataset']  # Kodak CLIC21 CLIC22
    lr = 1e-4
    train_lambda = 1  # NOTE(review): presumably the rate/distortion weight - confirm
    batch_size = 3
    num_workers = 8
    print_step = 50
    plot_step = 1000
    logger = None  # left unset here; presumably assigned by the training script

    # ---- training details ---------------------------------------------------
    # Indices 1 and 2 are used below as the spatial (height, width) size;
    # index 0 (= 3) is presumably the channel count.
    image_dims = (3, 256, 256)
    aux_lr = 1e-3
    distortion_metric = 'MSE'  # alternative: 'MS-SSIM'
    use_side_info = False
    eta = [0.4, 0.2]
    channel_adaptive = False
    # NOTE(review): 'chan_param' is presumably the AWGN SNR in dB - confirm.
    channel = {"type": 'awgn', 'chan_param': 10}
    modulation = False
    # modulation order
    modulation_order = 64
    # Candidate rates handed to the fe/fd networks as ``rate_choice``.
    # NOTE(review): the list is not evenly spaced (102, 118, 134, 186 break the
    # step-16 pattern) and its maximum equals the fe/fd embed_dim of 256 -
    # confirm both are intentional before editing.
    multiple_rate = [16, 32, 48, 64, 80, 96, 102, 118, 134, 160, 186, 192, 208, 224, 240, 256]

    # Constructor kwargs for the "ga" network (4 stages, shallow -> deep).
    ga_kwargs = dict(
        img_size=(image_dims[1], image_dims[2]),
        embed_dims=[256, 256, 256, 256], depths=[1, 1, 2, 4], num_heads=[8, 8, 8, 8],
        window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
        norm_layer=nn.LayerNorm, patch_norm=True,
    )
    # Constructor kwargs for the "gs" network (depths mirror ga_kwargs).
    gs_kwargs = dict(
        img_size=(image_dims[1], image_dims[2]),
        embed_dims=[256, 256, 256, 256], depths=[4, 2, 1, 1], num_heads=[8, 8, 8, 8],
        window_size=8, mlp_ratio=4., norm_layer=nn.LayerNorm, patch_norm=True
    )

    def _feature_kwargs(image_dims, rate_choice, channel_adaptive, modulation):
        """Return a fresh constructor-kwargs dict for the fe / fd networks.

        The fe and fd settings are identical, so they are defined once here.
        Each call builds a new dict with new ``depths`` / ``num_heads`` lists,
        so the two networks never share mutable containers. ``rate_choice`` is
        passed through by reference, as in a hand-written literal.
        """
        return dict(
            input_resolution=(image_dims[1] // 16, image_dims[2] // 16),
            embed_dim=256, depths=[4], num_heads=[8],
            window_size=16, mlp_ratio=4., qkv_bias=True, qk_scale=None,
            norm_layer=nn.LayerNorm, rate_choice=rate_choice,
            channel_adaptive=channel_adaptive, modulation=modulation,
        )

    # Functions defined in a class body cannot see class-level names, so the
    # settings they depend on are passed in explicitly.
    fe_kwargs = _feature_kwargs(image_dims, multiple_rate, channel_adaptive, modulation)
    fd_kwargs = _feature_kwargs(image_dims, multiple_rate, channel_adaptive, modulation)
    # Remove the helper so the class exposes exactly the same attributes as before.
    del _feature_kwargs

    # ---- turbo code encoder (enc_*) / decoder (dec_*) settings --------------
    enc_num_layer = 2
    code_rate_k = 1
    enc_num_unit = 100
    enc_kernel_size = 5
    dec_num_layer = 3
    num_iter_ft = 5
    dec_num_unit = 100
    dec_kernel_size = 5
    num_iteration = 5
    block_len = 200