Commit 0120146 (1 parent: 4d2fbfe). Showing 35 changed files with 2,454 additions and 1 deletion.
@@ -0,0 +1,6 @@
/dlt/evaluation/fid_pretrained/
/dlt/configs/local/
/dlt/visualization.py
/dlt/playground.py
/dlt/playground2.py
/dlt/scripts/
@@ -1,3 +1,92 @@
# [ICCV 23] DLT: Conditioned layout generation with Joint Discrete-Continuous Diffusion Layout Transformer

This repository is the official implementation of the DLT paper. Please refer to the [paper](https://arxiv.org/abs/2303.03755)
for more details and to the [project page](https://wix-incubator.github.io/DLT/) for a general overview.

| Unconditional | Category | Category + Size |
|---------------|----------|-----------------|
|  |  |  |

### Dev environment
- Operating System: Ubuntu 18.04
- CUDA Version: 11.6
- Python Version: 3.9

### Requirements
All relevant requirements are listed in [environment.yml](environment.yml). We recommend using
[conda](https://docs.conda.io/en/latest/) to create the appropriate environment and install the dependencies:
```bash
conda env create -f environment.yml
conda activate dlt
```
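
To confirm the new environment picked up a CUDA-enabled PyTorch build (assuming PyTorch is among the dependencies in environment.yml), a quick optional check:

```python
# Optional sanity check for the environment created above.
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())  # expect True on the CUDA 11.6 setup
```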
### Datasets
Please download the public datasets from the webpages below. Put them in your data folder and update
`./dlt/configs/remote/dataset_config.yaml` accordingly.

1. [RICO](https://interactionmining.org/rico)
2. [PubLayNet](https://github.com/ibm-aur-nlp/PubLayNet)
3. [Magazine](https://xtqiao.com/projects/content_aware_layout/)

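The training configs in this commit expect each dataset directory to contain `train.json` and `val.json` (see the magazine and publaynet configs below). A minimal sketch to verify that layout before training; the helper and the example path are hypothetical:

```python
# Hypothetical helper: check a dataset directory has the expected split files.
from pathlib import Path

def check_dataset(root: str) -> None:
    for split in ("train.json", "val.json"):
        path = Path(root) / split
        print(f"{path}: {'found' if path.exists() else 'MISSING'}")

check_dataset("/home/ubuntu/dataset/publaynet")  # match your dataset_config.yaml paths
```
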
### Training
You can train the model using any config script in the [configs](./dlt/configs) folder. For example, to train the
provided DLT model on the PubLayNet dataset:

```bash
cd dlt
python main.py --config configs/remote/dlt_publaynet_config.yaml --workdir <WORKDIR>
```
Note that the code is accelerator-agnostic. If you don't want to log results to wandb, just set `--workdir test`
in the args.

### Evaluation

To generate samples for evaluation on the test set, follow these steps:

- Train the model using the command above.
- Run the following:

```bash
# put the trained weights in the configured logs folder
DATASET="publaynet"  # or "rico" or "magazine"
python generate_samples.py --config configs/remote/dlt_${DATASET}_config.yaml \
    --workdir <WORKDIR> --epoch <EPOCH> --cond_type <COND_TYPE> \
    --save True
# get all the metrics
# (first update the path to the pickle file in dlt/evaluation/metric_comp.py)
./download_fid_model.sh
python metric_comp.py
```
where `<COND_TYPE>` is one of `all`, `whole_box`, or `loc` (unconditional, category, and category + size conditioning,
respectively), `<EPOCH>` is the epoch number of the model you want to evaluate, and `<WORKDIR>` is the folder where
the model weights are saved (e.g. `rico_final`). If `--save` is `True`, the generated samples are saved in the `logs/<WORKDIR>/samples` folder.

The output is a pickle file of generated samples, which is then used to compute the metrics.
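
Before wiring the pickle into `metric_comp.py`, a quick look at its top-level structure can help; the path below is hypothetical, and the sample schema isn't documented in this commit:

```python
# Inspect the generated-samples pickle (path and schema are assumptions).
import pickle

with open("logs/rico_final/samples/samples.pkl", "rb") as f:
    samples = pickle.load(f)

print(type(samples))
if hasattr(samples, "__len__"):
    print(len(samples), "samples")
```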

The folder with weights after training has this structure:
```
logs
├── magazine_final
│   ├── checkpoints
│   └── samples
├── publaynet_final
│   ├── checkpoints
│   └── samples
└── rico_final
    ├── checkpoints
    └── samples
```
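
To see which checkpoints a run has produced before choosing `<EPOCH>`, you can list the `checkpoints` folder; the naming convention of the files inside is not shown in this commit:

```python
# Sketch: list saved checkpoints for a run (file naming is an assumption).
from pathlib import Path

workdir = Path("logs") / "rico_final"
checkpoints = sorted((workdir / "checkpoints").iterdir())
print(f"{len(checkpoints)} checkpoint(s) found")
for ckpt in checkpoints[-3:]:  # show the most recent few
    print(ckpt.name)
```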

## Citation

If you find this code useful for your research, please cite our paper:

```
@misc{levi2023dlt,
      title={DLT: Conditioned layout generation with Joint Discrete-Continuous Diffusion Layout Transformer},
      author={Elad Levi and Eli Brosh and Mykola Mykhailych and Meir Perez},
      year={2023},
      eprint={2303.03755},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```
@@ -0,0 +1,65 @@
import ml_collections
import torch
from path import Path


def get_config():
    """Gets the default hyperparameter configuration."""

    config = ml_collections.ConfigDict()
    config.log_dir = Path('/home/ubuntu/logs')
    # Exp info
    config.dataset_path = Path("/home/ubuntu/dataset/magazine")
    config.train_json = config.dataset_path / 'train.json'
    config.val_json = config.dataset_path / 'val.json'
    config.resume_from_checkpoint = None

    config.dataset = "magazine"
    config.max_num_comp = 16

    # Training info
    config.seed = 42
    # data specific
    config.categories_num = 7
    # model specific
    config.latent_dim = 512
    config.num_layers = 4
    config.num_heads = 8
    config.dropout_r = 0.0
    config.activation = "gelu"
    config.cond_emb_size = 224
    config.cls_emb_size = 64
    # diffusion specific
    config.num_cont_timesteps = 100
    config.num_discrete_steps = 10
    config.beta_schedule = "squaredcos_cap_v2"

    # Training info
    config.log_interval = 44
    config.save_interval = 10_000

    config.optimizer = ml_collections.ConfigDict()
    config.optimizer.num_gpus = torch.cuda.device_count()

    config.optimizer.mixed_precision = 'no'
    config.optimizer.gradient_accumulation_steps = 1
    config.optimizer.betas = (0.95, 0.999)
    config.optimizer.epsilon = 1e-8
    config.optimizer.weight_decay = 1e-6

    config.optimizer.lr_scheduler = 'cosine'
    config.optimizer.num_warmup_steps = 2_000
    config.optimizer.lr = 0.0001

    config.optimizer.num_epochs = 2000
    config.optimizer.batch_size = 64
    config.optimizer.split_batches = False
    config.optimizer.num_workers = 1

    config.optimizer.lmb = 5

    if config.optimizer.num_gpus == 0:
        config.device = 'cpu'
    else:
        config.device = 'cuda'
    return config
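
The config files in this commit are plain Python modules built on `ml_collections`. A minimal sketch of how an entry point might consume one; the import path is hypothetical, and the repo's `main.py` may resolve configs differently:

```python
# Sketch: load and override a config module (import path is an assumption).
from configs.remote import dlt_magazine_config

config = dlt_magazine_config.get_config()
print(config.dataset, config.device, config.optimizer.lr)

config.optimizer.batch_size = 32  # ConfigDict fields can be overridden in place
```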
@@ -0,0 +1,66 @@
import ml_collections
import torch
from path import Path


def get_config():
    """Gets the default hyperparameter configuration."""

    config = ml_collections.ConfigDict()
    config.log_dir = Path('/home/ubuntu/logs')
    # Exp info
    config.dataset_path = Path("/home/ubuntu/dataset/publaynet")
    config.train_json = config.dataset_path / 'train.json'
    config.val_json = config.dataset_path / 'val.json'

    config.resume_from_checkpoint = None

    config.dataset = "publaynet"
    config.max_num_comp = 9

    # Training info
    config.seed = 42
    # data specific
    config.categories_num = 7
    # model specific
    config.latent_dim = 512
    config.num_layers = 4
    config.num_heads = 8
    config.dropout_r = 0.0
    config.activation = "gelu"
    config.cond_emb_size = 224
    config.cls_emb_size = 64
    # diffusion specific
    config.num_cont_timesteps = 100
    config.num_discrete_steps = 10
    config.beta_schedule = "squaredcos_cap_v2"

    # Training info
    config.log_interval = 2647
    config.save_interval = 50_000

    config.optimizer = ml_collections.ConfigDict()
    config.optimizer.num_gpus = torch.cuda.device_count()

    config.optimizer.mixed_precision = 'no'
    config.optimizer.gradient_accumulation_steps = 1
    config.optimizer.betas = (0.95, 0.999)
    config.optimizer.epsilon = 1e-8
    config.optimizer.weight_decay = 1e-6

    config.optimizer.lr_scheduler = 'cosine'
    config.optimizer.num_warmup_steps = 100_000
    config.optimizer.lr = 0.0001

    config.optimizer.num_epochs = 800
    config.optimizer.batch_size = 64
    config.optimizer.split_batches = False
    config.optimizer.num_workers = 4

    config.optimizer.lmb = 5

    if config.optimizer.num_gpus == 0:
        config.device = 'cpu'
    else:
        config.device = 'cuda'
    return config
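
All three dataset configs set `beta_schedule = "squaredcos_cap_v2"` with `num_cont_timesteps = 100`. For intuition, here is a sketch of that noise schedule; the formula below is the squared-cosine alpha-bar schedule that the diffusers library implements under this name, and whether the training code actually uses diffusers for this is an assumption:

```python
# Sketch of the "squaredcos_cap_v2" beta schedule (squared-cosine alpha-bar, betas capped).
import numpy as np

def squaredcos_cap_v2_betas(num_steps: int, max_beta: float = 0.999) -> np.ndarray:
    def alpha_bar(t: float) -> float:
        # Cumulative signal level; cosine schedule of Nichol & Dhariwal (2021).
        return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2

    return np.array([
        min(1 - alpha_bar((i + 1) / num_steps) / alpha_bar(i / num_steps), max_beta)
        for i in range(num_steps)
    ])

print(squaredcos_cap_v2_betas(100)[:5])  # num_cont_timesteps = 100 in these configs
```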
@@ -0,0 +1,63 @@
import ml_collections
import torch
from path import Path


def get_config():
    """Gets the default hyperparameter configuration."""

    config = ml_collections.ConfigDict()
    config.log_dir = Path('/home/ubuntu/logs')
    # Exp info
    config.dataset_path = Path("/home/ubuntu/dataset/rico")
    config.resume_from_checkpoint = None

    config.dataset = "rico"
    config.max_num_comp = 10

    # Training info
    config.seed = 42
    # data specific
    config.categories_num = 15
    # model specific
    config.latent_dim = 512
    config.num_layers = 4
    config.num_heads = 8
    config.dropout_r = 0.0
    config.activation = "gelu"
    config.cond_emb_size = 224
    config.cls_emb_size = 64
    # diffusion specific
    config.num_cont_timesteps = 100
    config.num_discrete_steps = 10
    config.beta_schedule = "squaredcos_cap_v2"

    # Training info
    config.log_interval = 500
    config.save_interval = 10_000

    config.optimizer = ml_collections.ConfigDict()
    config.optimizer.num_gpus = torch.cuda.device_count()

    config.optimizer.mixed_precision = 'no'
    config.optimizer.gradient_accumulation_steps = 1
    config.optimizer.betas = (0.95, 0.999)
    config.optimizer.epsilon = 1e-8
    config.optimizer.weight_decay = 1e-6

    config.optimizer.lr_scheduler = 'cosine'
    config.optimizer.num_warmup_steps = 10_000
    config.optimizer.lr = 0.0001

    config.optimizer.num_epochs = 800
    config.optimizer.batch_size = 64
    config.optimizer.split_batches = False
    config.optimizer.num_workers = 4

    config.optimizer.lmb = 5

    if config.optimizer.num_gpus == 0:
        config.device = 'cpu'
    else:
        config.device = 'cuda'
    return config
@@ -0,0 +1,83 @@
import numpy as np


def norm_bbox(H, W, element):
    # Convert [x1, y1, width, height] into [xc, yc, w, h], with the center
    # and size expressed as fractions of the canvas dimensions (W, H).
    x1, y1, width, height = element['bbox']
    xc = x1 + width / 2.
    yc = y1 + height / 2.
    b = [xc / W, yc / H,
         width / W, height / H]
    return b


def is_valid_comp(comp, W, H):
    # A component is valid if its box lies inside the canvas and has positive area.
    x1, y1, width, height = comp['bbox']
    x2, y2 = x1 + width, y1 + height
    if x1 < 0 or y1 < 0 or W < x2 or H < y2:
        return False

    if x2 <= x1 or y2 <= y1:
        return False

    return True


def mask_loc(bbox_shape, r_mask=1.0):
    # Flag the location entries (xc, yc) of a random fraction r_mask of the boxes.
    n, _ = bbox_shape
    ind_mask = np.random.choice(range(n), int(n * r_mask), replace=False)
    mask = np.zeros(bbox_shape)
    mask[ind_mask, :2] = 1
    full_mask_cat = np.zeros(n).astype('long')
    return mask, full_mask_cat


def mask_size(bbox_shape, r_mask=1.0):
    # Flag the size entries (w, h) of a random fraction r_mask of the boxes.
    n, _ = bbox_shape
    ind_mask = np.random.choice(range(n), int(n * r_mask), replace=False)
    mask = np.zeros(bbox_shape)
    mask[ind_mask, 2:] = 1
    full_mask_cat = np.zeros(n).astype('long')
    return mask, full_mask_cat


def mask_whole_box(bbox_shape, r_mask=1.0):
    # Flag all four box entries of a random fraction r_mask of the boxes.
    n, _ = bbox_shape
    ind_mask = np.random.choice(range(n), int(n * r_mask), replace=False)
    mask = np.zeros(bbox_shape)
    mask[ind_mask, :4] = 1
    full_mask_cat = np.zeros(n).astype('long')
    return mask, full_mask_cat


def mask_all(bbox_shape):
    # Flag every box entry and every category.
    n, _ = bbox_shape
    mask = np.ones(bbox_shape)
    full_mask_cat = np.ones(n).astype('long')
    return mask, full_mask_cat


def mask_cat(bbox_shape, r_mask=1.0):
    # Flag only the categories of a random fraction r_mask of the boxes;
    # the box mask stays all zeros.
    n, dim = bbox_shape
    ind_mask = np.random.choice(range(n), int(n * r_mask), replace=False)
    mask = np.zeros(bbox_shape)
    full_mask_cat = np.zeros(n).astype('long')
    full_mask_cat[ind_mask] = 1
    return mask, full_mask_cat


def mask_random_box_and_cat(bbox_shape, r_mask_box=1.0, r_mask_cat=1.0):
    # Pick a random box-masking strategy (location, size, location + size, or
    # whole box), then randomly decide whether categories are flagged as well.
    n, _ = bbox_shape
    func_options = [mask_loc, mask_size, [mask_loc, mask_size], mask_whole_box]
    ix = np.random.choice(range(len(func_options)), 1)[0]
    func_mask_box = func_options[ix]
    if isinstance(func_mask_box, list):
        mask_box = np.zeros(bbox_shape)
        for func in func_mask_box:
            m, _ = func(bbox_shape, r_mask_box)
            mask_box += m
        all_cat_mask = np.zeros(n).astype('long')
    else:
        mask_box, all_cat_mask = func_mask_box(bbox_shape, r_mask_box)
    _, full_mask_cat = mask_cat(bbox_shape, r_mask_cat)
    cat_mask_options = [all_cat_mask, full_mask_cat]
    return mask_box, cat_mask_options[np.random.choice(len(cat_mask_options), 1)[0]]
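
For intuition about the masking helpers above, a tiny usage sketch on a 3-box layout (run in the same module, so the functions are in scope); whether a flag of 1 means "conditioned" or "to be generated" depends on how the training loop consumes these masks:

```python
# Tiny demo of the masking helpers on a layout of 3 boxes, each [xc, yc, w, h].
import numpy as np

np.random.seed(0)
bbox_shape = (3, 4)

loc_mask, cat_mask = mask_loc(bbox_shape, r_mask=1.0)
print(loc_mask)  # 1s in the first two columns (xc, yc) of every box
print(cat_mask)  # all zeros: mask_loc leaves categories untouched

box_mask, cat_flags = mask_random_box_and_cat(bbox_shape)
print(box_mask.shape, cat_flags.shape)  # (3, 4) and (3,)
```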