-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain2.py
75 lines (61 loc) · 2.28 KB
/
train2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Optional
import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
import torch
from data_voc import VOCSegmentation
from method2 import MultiqueryslotMethod
from model2 import PatchModel
from data2 import TrainTransforms,DinoTransforms
from params import SlotAttentionParams
def _make_voc_split(params, image_set, **extra):
    """Build one PASCAL VOC 2012 ``VOCSegmentation`` split.

    Every split built by this script uses a freshly constructed
    ``TrainTransforms`` pipeline and ``DinoTransforms`` pipeline. Any ``extra``
    keyword arguments (e.g. ``evo=True``) are forwarded to ``VOCSegmentation``
    unchanged.

    Args:
        params: Object providing ``data_root``.
        image_set: VOC split name, e.g. ``"train"``, ``"val"``, ``"trainval"``.
        **extra: Additional keyword arguments for ``VOCSegmentation``.

    Returns:
        The constructed ``VOCSegmentation`` dataset.
    """
    return VOCSegmentation(
        root=params.data_root,
        year="2012",
        image_set=image_set,
        transform=TrainTransforms().transforms,
        dino_transform=DinoTransforms().transforms,
        **extra,
    )


def _trainer_device_kwargs(gpus):
    """Return the ``accelerator``/``devices`` kwargs for a Lightning ``Trainer``.

    With ``gpus > 0`` this requests that many GPUs. Otherwise it requests a
    single CPU device: Lightning 2.x rejects ``devices=0``, so ``gpus`` must not
    be forwarded verbatim on the CPU path.
    """
    if gpus > 0:
        return {"accelerator": "gpu", "devices": gpus}
    return {"accelerator": "cpu", "devices": 1}


def main(params: Optional[SlotAttentionParams] = None):
    """Train ``MultiqueryslotMethod`` on PASCAL VOC 2012 segmentation.

    Builds the VOC 2012 ``train``/``val`` datasets and their data loaders, a
    ``PatchModel``, and a Lightning ``Trainer``, then runs ``trainer.fit``.
    After training, the model is put in eval mode and moved to the target
    device, and a ``trainval`` split with ``evo=True`` is constructed.

    Args:
        params: Hyper-parameters. A default ``SlotAttentionParams()`` is used
            when ``None``.

    Raises:
        ValueError: If ``params.num_slots`` is less than 2.
    """
    if params is None:
        params = SlotAttentionParams()
    # An explicit raise still fires under ``python -O``, unlike ``assert``.
    if params.num_slots <= 1:
        raise ValueError("Must have at least 2 slots.")

    train = _make_voc_split(params, "train")
    val = _make_voc_split(params, "val")
    masking = params.masking_ratio  # type: ignore

    train_dataloader = torch.utils.data.DataLoader(
        train,
        batch_size=params.batch_size,
        shuffle=True,
        num_workers=params.num_workers,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val,
        batch_size=params.val_batch_size,
        shuffle=False,
        num_workers=params.num_workers,
    )

    model = PatchModel(
        num_slots=params.num_slots,
        num_iterations=params.num_iterations,
        empty_cache=params.empty_cache,
        # NOTE(review): hard-coded; presumably matches the DINO feature width
        # produced by DinoTransforms' backbone. Confirm before changing.
        slot_size=384,
        masking=masking,
    )
    # NOTE(review): a Dataset (not a LightningDataModule) is passed as
    # ``datamodule``. Kept as-is; check what MultiqueryslotMethod expects.
    method = MultiqueryslotMethod(model=model, datamodule=train, params=params)

    trainer = Trainer(
        **_trainer_device_kwargs(params.gpus),
        strategy="ddp_find_unused_parameters_true",
        num_sanity_val_steps=params.num_sanity_val_steps,
        max_epochs=params.max_epochs,
        log_every_n_steps=10,
        callbacks=[LearningRateMonitor("step")] if params.is_logger_enabled else [],
    )
    trainer.fit(method, train_dataloader, val_dataloader)

    model.eval()
    model = model.to("cuda" if params.gpus > 0 else "cpu")
    # NOTE(review): this split is built but not used afterwards. The evaluation
    # step that should consume it appears to be missing.
    eval_split = _make_voc_split(params, "trainval", evo=True)
# Run training with the default hyper-parameters when executed as a script.
if __name__ == "__main__":
    main(params=None)