The code contains the following files:
train.py - contains the main training loop. Includes an argument parser for choosing hyperparameters (model architecture, e.g. ResNet depth and decoder type; number of epochs; learning rate; etc.)
eval.py - generates the output JSON files for uploading to the evaluation server
utils.py - utility functions for PyTorch
vocab.py - builds the vocabulary
models.py - contains the code for the different models, and an implementation of beam search
preprocess_glove.py - preprocesses the data and converts it to GloVe embedding form
test_coco_eval.py - calculates evaluation metrics on the validation set
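For readers unfamiliar with the beam search mentioned for models.py, here is a minimal, generic sketch of the idea. It is not the implementation from models.py: the names beam_search, step_fn and toy_step are hypothetical, the "model" is a hard-coded table of next-token probabilities, and no length normalization is applied.

```python
import math


def beam_search(step_fn, start_token, end_token, beam_size=3, max_len=25):
    """Generic beam search (illustrative only, not the code in models.py).

    step_fn(seq) must return a dict {next_token: log_probability}.
    """
    beams = [([start_token], 0.0)]  # (sequence, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            for token, logp in step_fn(seq).items():
                candidates.append((seq + [token], score + logp))
        # Keep only the beam_size highest-scoring expansions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            if seq[-1] == end_token:
                finished.append((seq, score))
            else:
                beams.append((seq, score))
        if not beams:
            break
    # Fall back to unfinished beams if nothing reached the end token.
    pool = finished or beams
    return max(pool, key=lambda c: c[1])[0]


def toy_step(seq):
    """Hypothetical stand-in for a decoder: fixed next-token probabilities."""
    table = {
        "<start>": {"a": math.log(0.6), "the": math.log(0.4)},
        "a": {"dog": math.log(0.5), "cat": math.log(0.5)},
        "the": {"dog": math.log(0.9), "cat": math.log(0.1)},
        "dog": {"<end>": 0.0},
        "cat": {"<end>": 0.0},
    }
    return table[seq[-1]]


if __name__ == "__main__":
    print(beam_search(toy_step, "<start>", "<end>", beam_size=1))
    print(beam_search(toy_step, "<start>", "<end>", beam_size=2))
```

With beam_size=1 the search is greedy and commits to "a" first, while beam_size=2 also keeps "the" alive and ends up with the higher-probability caption "the dog". In a real captioning model, step_fn would wrap one decoder step of the RNN.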
The shell commands to be used are shown below:
Training:
python train.py \
--models_dir sat_resnet50_lstm1_e512_h512_b256_glove \
--vocab_path vocab2014.pkl \
--image_root train2014/ \
--captions_json annotations/captions_train2014.json \
--embed_size 300 \
--rnn_type lstm \
--resnet_size 50 \
--glove_embed_path glove_embeddings.pkl
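As a rough guide to how the flags in the training command might be declared, the sketch below builds an argparse parser for them. The flag names come from the commands in this README; the types, defaults, required settings and help strings are assumptions made for illustration, and train.py itself may define more options (for example epochs and learning rate) or different defaults.

```python
import argparse


def build_parser():
    """Illustrative parser; defaults and help texts are placeholders, not train.py's."""
    parser = argparse.ArgumentParser(description="Image captioning training (sketch)")
    parser.add_argument("--models_dir", type=str, required=True,
                        help="directory where checkpoints are written")
    parser.add_argument("--vocab_path", type=str, required=True,
                        help="pickled vocabulary file")
    parser.add_argument("--image_root", type=str, required=True,
                        help="directory containing the images")
    parser.add_argument("--captions_json", type=str, required=True,
                        help="COCO captions annotation file")
    # Placeholder defaults below: assumed values, not confirmed by the source.
    parser.add_argument("--embed_size", type=int, default=512)
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--rnn_type", type=str, default="lstm")
    parser.add_argument("--resnet_size", type=int, default=50)
    parser.add_argument("--glove_embed_path", type=str, default=None,
                        help="optional pickled GloVe embeddings")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Note that the training command passes --embed_size 300, presumably so that the embedding size matches the GloVe vectors being loaded.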
Evaluation:
python eval.py \
--vocab_path vocab2014.pkl \
--image_root val2014 \
--eval_ckpt_path sat_resnet50_lstm1_e512_h512_b256_glove/model-after-epoch-19.ckpt \
--glove_embed_path glove_embeddings.pkl \
--batch_size 512 \
--num_workers 8 \
--embed_size 300 \
--hidden_size 512 \
--rnn_type lstm \
--resnet_size 50 \
--caption_maxlen 25 \
--results_json_path sat_resnet50_lstm1_e512_h512_b256_glove_model-after-epoch-19.json
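The file written to --results_json_path is meant for the evaluation server. Assuming this is the COCO captioning server, the expected format is a JSON list of objects with an integer image_id and a caption string. The sketch below writes such a file; the helper name write_results and its input dictionary are hypothetical, and eval.py may organize this step differently.

```python
import json


def write_results(captions_by_image_id, results_json_path):
    """Write captions as a list of {"image_id": int, "caption": str} records."""
    results = [
        {"image_id": int(image_id), "caption": caption}
        for image_id, caption in sorted(captions_by_image_id.items())
    ]
    with open(results_json_path, "w") as f:
        json.dump(results, f)
    return results


if __name__ == "__main__":
    # Made-up image ids and captions, for illustration only.
    write_results({42: "a dog on a beach", 7: "a cat"}, "example_results.json")
```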