- The VSDv2 dataset is now available.

This repository contains code and data for our paper *Visual Spatial Description: Controlled Spatial-Oriented Image-to-Text Generation*.

**Note:** Please go into VLT5 and follow the README there for Pretrained Models and Feature Extraction.
```bash
# Create python environment (optional)
conda create -n vsd python=3.10
source activate vsd

# Install python dependencies
pip install -r requirements.txt
```
```bash
# Store images, features, and annotations
./datasets

# Image feature extraction
./feature_extraction

# Train VL-T5
./VL-T5/
    src/
        modeling_t5.py, modeling_bart.py  <= VL-T5/VL-BART model classes
        caption_sp.py, vrd_caption.py     <= fine-tuning
        param.py                          <= (argparse) configuration
        tokenization.py                   <= custom tokenizer
        utils.py, dist_utils.py           <= utility functions
    snap/                                 <= store weight checkpoints
```
- Pretrained VL-BART and VL-T5 checkpoints are provided by VLT5.
- Download `snap/` from Google Drive or feijiang.
- This dataset is created from VG and SpatialSense images by running `python feature_extraction/sp_proposal.py` or `python feature_extraction/vg_proposal.py`.
- All you need to do is put all of the VG and SpatialSense images in the same folder.
- The final `.h5` file can also be downloaded from Google Drive or feijiang.
- Put the `.h5` file in the `datasets` folder and name it `vsd_boxes36.h5` (a quick sanity check is sketched below).
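The feature-extraction scripts above define the internal layout of `vsd_boxes36.h5`, so the sketch below only checks that the file opens and prints whatever entries it finds; it does not assume any particular schema.

```python
import h5py

# Minimal sanity check for the extracted features file (path assumes the
# layout described above). The key names printed here are whatever the
# extraction scripts wrote; adjust expectations to your own file.
with h5py.File("datasets/vsd_boxes36.h5", "r") as f:
    keys = list(f.keys())
    print(f"top-level entries: {len(keys)}")

    first = f[keys[0]]
    if isinstance(first, h5py.Group):
        # e.g. one group per image holding box/feature datasets
        print(keys[0], "->", {k: getattr(v, "shape", "group") for k, v in first.items()})
    else:
        # e.g. a single flat dataset
        print(keys[0], "->", first.shape, first.dtype)
```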
- If your network cannot reach Hugging Face by URL, you may need to download facebook/bart-base and t5-base locally, as sketched below.
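A minimal sketch of pre-downloading both backbones with `transformers` on a machine that can reach Hugging Face; the local directory names are only examples, and the configuration (see `param.py`) would then point at these local paths instead of the hub names.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Download once on a connected machine, then copy the folders to the
# offline machine. The target directories below are arbitrary examples.
for name, local_dir in [("facebook/bart-base", "./local_models/bart-base"),
                        ("t5-base", "./local_models/t5-base")]:
    AutoTokenizer.from_pretrained(name).save_pretrained(local_dir)
    AutoModelForSeq2SeqLM.from_pretrained(name).save_pretrained(local_dir)
```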
```bash
bash train_b16.sh num_gpu
bash train_b80.sh num_gpu

# or, if you want to train only one of them
bash test_batch_80/baseline_bart.sh num_gpu
# the other scripts follow the same pattern
```

```bash
bash test_b16.sh 1 --use_golden
bash test_b80.sh 1 --use_golden
```

- The result files will be saved in the `test_16_res` and `test_80_res` folders.

Model | BLEU-4 | METEOR | ROUGE | CIDEr | SPICE | Acc |
---|---|---|---|---|---|---|
VLBART | 54.52 | 43.10 | 78.79 | 482.64 | 68.95 | - |
VLBART-end2end | 52.90 | 42.15 | 77.60 | 469.65 | 67.64 | 52.22 |
VLBART-end2end-golden | 71.94 | 50.93 | 87.17 | 571.46 | 76.66 | 52.22 |
VLT5 | 54.72 | 43.26 | 79.04 | 484.09 | 68.95 | - |
VLT5-end2end | 53.88 | 42.88 | 78.98 | 481.18 | 68.88 | 54.38 |
VLT5-end2end-golden | 72.24 | 51.21 | 87.92 | 576.20 | 76.95 | 54.38 |
VLT5-end2end-onestep-test-false | 53.71 | 42.76 | 78.49 | 476.68 | 68.31 | 52.41 |
VLT5-end2end-onestep-test-true | 53.82 | 42.75 | 78.58 | 477.96 | 68.58 | 52.41 |
VLT5-end2end-onestep-golden | 52.85 | 42.45 | 78.18 | 472.35 | 68.00 | 52.41 |

Model | BLEU-4 | METEOR | ROUGE | CIDEr | SPICE | Acc |
---|---|---|---|---|---|---|
VLBART | 52.73 | 42.35 | 77.91 | 471.97 | 67.74 | - |
VLBART-end2end | 53.19 | 42.10 | 77.76 | 470.02 | 68.06 | 52.81 |
VLBART-end2end-golden | 71.77 | 50.75 | 87.28 | 568.66 | 76.80 | 52.81 |
VLT5 | 54.44 | 43.03 | 78.82 | 484.02 | 68.92 | - |
VLT5-end2end | 54.76 | 43.10 | 79.08 | 481.46 | 68.58 | 53.27 |
VLT5-end2end-golden | 73.49 | 51.77 | 88.48 | 582.18 | 77.48 | 53.27 |

Model | BLEU-4 | METEOR | ROUGE | CIDEr | SPICE | Acc |
---|---|---|---|---|---|---|
VLBART | 25.53 | 26.78 | 55.95 | 276.40 | 48.28 | - |
VLBART-end2end | | | | | | |
VLBART-end2end-golden | | | | | | |
VLT5 | 25.83 | 26.88 | 55.79 | 277.87 | 48.88 | - |
VLT5-end2end | 24.99 | 26.65 | 55.45 | 275.41 | 49.22 | |
VLT5-end2end-golden | | | | | | |

This repo is adapted from VLT5.

There are three training modes:
1. baseline: the input_id is 'sub obj'.
2. train (one_step_dec=True): the input_id is 'sub <extral_id_0> obj'.
3. train (one_step_dec=False): the input_id is 'sub <extral_id_i> obj', where i is the index of the corresponding spatial relation word.

The settings reported in the paper use the third training mode.

There are three testing modes:
1. golden=False, one_step_dec=True: the test input_id is 'sub <extral_id_0> obj'.
2. golden=False, one_step_dec=False: the test input_id is 'sub <extral_id_i> obj', where i is the id of the spatial relation word predicted by the model.
3. golden=True: the test input_id is 'sub <extral_id_i> obj', where i is the id of the spatial relation word annotated in the dataset.
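A minimal sketch of how these three input patterns could be assembled, assuming a simple relation-to-index mapping; the relation list and helper below are purely illustrative (the actual vocabulary and tokenization live in the repository, see tokenization.py), and the '<extral_id_i>' spelling follows the notation above.

```python
# Illustrative only: a toy relation vocabulary and a helper that builds the
# three input_id text patterns described above.
SPATIAL_RELATIONS = ["on", "under", "in front of", "behind", "left of", "right of"]

def build_input_text(sub, obj, mode, relation=None):
    if mode == "baseline":
        # 1. baseline: no placeholder between subject and object
        return f"{sub} {obj}"
    if mode == "one_step":
        # 2. one_step_dec=True: a fixed sentinel between subject and object
        return f"{sub} <extral_id_0> {obj}"
    if mode == "relation_specific":
        # 3. one_step_dec=False: the sentinel index i is the id of the
        #    spatial relation (annotated during training; predicted or
        #    annotated at test time, depending on the golden flag)
        i = SPATIAL_RELATIONS.index(relation)
        return f"{sub} <extral_id_{i}> {obj}"
    raise ValueError(f"unknown mode: {mode}")

print(build_input_text("cat", "table", "relation_specific", relation="behind"))
# -> "cat <extral_id_3> table"
```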