Skip to content

Commit

Permalink
new exp with upsampling before SDS
Browse files Browse the repository at this point in the history
  • Loading branch information
jadevaibhav committed Sep 23, 2024
1 parent bfe69c8 commit 96cb5b8
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ coverage.xml
.pytest_cache/
cover/

# Slurm logs
slurm*
# Translations
*.mo
*.pot
Expand Down
33 changes: 30 additions & 3 deletions threestudio/systems/dreamfusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,16 +174,43 @@ def configure(self):
# create geometry, material, background, renderer
super().configure()

def unmask(self, ind, subsampled_tensor, H, W):
    """Bilinearly upsample a sub-sampled render back to ``(H, W)`` resolution.

    The sub-sampled pixels are assumed to lie on a regular grid that starts
    at a (random) top-left offset and spans 3/4 of the full image along each
    axis, which is how ``mask_ray_directions`` builds its index map
    (``linspace(0, 0.75 * W, s_W)`` plus an offset).

    Args:
        ind: flat pixel indices of the sampled rays, shape ``(s_H, s_W)`` or
            ``(B, s_H, s_W)``. Only the very first entry (the top-left
            sample) is read; it is decoded as ``x = ind % H`` and
            ``y = ind // H``, mirroring the ``x + H * y`` encoding used when
            the indices are generated.
        subsampled_tensor: sub-sampled image, shape ``(B, C, s_H, s_W)``.
        H: height of the full-resolution image.
        W: width of the full-resolution image.

    Returns:
        Tensor of shape ``(B, H, W, C)`` with the same dtype and device as
        ``subsampled_tensor``. Pixels that fall outside the sampled region
        are filled by ``grid_sample``'s default zero padding.
    """
    device = subsampled_tensor.device
    batch_size = subsampled_tensor.shape[0]

    # The first element in flattened order is the top-left sample no matter
    # whether ``ind`` carries a leading batch dimension or not.
    first_index = ind.reshape(-1)[0].to(device)
    offset_x = first_index % H
    offset_y = first_index // H

    # Pixel coordinates of the full-resolution image; with "xy" indexing both
    # outputs have shape (H, W).
    grid_x, grid_y = torch.meshgrid(
        torch.arange(W, dtype=torch.float32, device=device),
        torch.arange(H, dtype=torch.float32, device=device),
        indexing="xy",
    )

    # The samples cover 0.75 of each axis, so dividing by 0.75 * size
    # (i.e. multiplying by 4 / (3 * size)) maps the sampled span onto [0, 1].
    grid = torch.stack(
        [
            (grid_x - offset_x) * 4 / (3 * W),
            (grid_y - offset_y) * 4 / (3 * H),
        ],
        dim=-1,
    )
    # grid_sample expects normalised coordinates in [-1, 1].
    grid = grid * 2 - 1

    # grid_sample requires the grid to match the input's dtype; the grid is
    # identical for every batch element, so a broadcast view is sufficient.
    grid = grid.to(subsampled_tensor.dtype).unsqueeze(0).expand(batch_size, -1, -1, -1)

    # (B, C, s_H, s_W) -> (B, C, H, W)
    upsampled_tensor = torch.nn.functional.grid_sample(
        subsampled_tensor, grid, mode="bilinear", align_corners=True
    )

    # Channels-last layout: (B, H, W, C).
    return upsampled_tensor.permute(0, 2, 3, 1)

def training_step(self, batch, batch_idx):
out = self(batch)
### using mask to create image at original resolution during training
(B,s_H,s_W,C) = out["comp_rgb"].shape
comp_rgb = out["comp_rgb"].permute(0,3,1,2)
mask = batch["efficiency_mask"]
comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C)
comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C)
comp_rgb = self.unmask(mask,comp_rgb,batch["height"],batch["width"])
# comp_rgb = torch.zeros(B,batch["height"],batch["width"],C,device=self.device).view(B,-1,C)
# comp_rgb[:,mask.view(-1)] = out["comp_rgb"].view(B,-1,C)
out.update(
{
"comp_rgb": comp_rgb.view(B,batch["height"],batch["width"],C),
"comp_rgb": comp_rgb,
}
)

Expand Down
54 changes: 31 additions & 23 deletions threestudio/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,33 +227,41 @@ def mask_ray_directions(
pixels from (s_H,s_W) are sampled more (1-aspect_ratio) than outside pixels(aspect_ratio).
the masking is deferred to before calling get_rays().
"""
indices_all = torch.meshgrid(
torch.arange(W, dtype=torch.float32) ,
torch.arange(H, dtype=torch.float32) ,
indexing="xy",
)
# indices_inner = torch.meshgrid(
# torch.arange((W-s_W)//2 , W - math.ceil((W-s_W)/2), dtype=torch.float32) ,
# torch.arange((H-s_H)//2,H - math.ceil((H-s_H)/2), dtype=torch.float32) ,
# indices_all = torch.meshgrid(
# torch.arange(W, dtype=torch.float32) ,
# torch.arange(H, dtype=torch.float32) ,
# indexing="xy",
# )
mask = torch.zeros(H,W, dtype=torch.bool)
mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True

indices_inner = torch.meshgrid(
torch.linspace(0,0.75*W,s_W, dtype=torch.int8) ,
torch.linspace(0,0.75*H,s_H, dtype=torch.int8) ,
indexing="xy",
)
offset = [torch.randint(0,W//8 +1,(1,)),
torch.randint(0,H//8 +1,(1,))]

select_ind = indices_inner[0]+offset[0] + H*(indices_inner[1] + offset[1])

in_ind_1d = (indices_all[0]+H*indices_all[1])[mask]
out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)]
### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already
### leads to more samples inside anyways

### removing the random sampling approach, we sample in uniform grid
# mask = torch.zeros(H,W, dtype=torch.bool)
# mask[(H-s_H)//2 : H - math.ceil((H-s_H)/2),(W-s_W)//2 : W - math.ceil((W-s_W)/2)] = True

# in_ind_1d = (indices_all[0]+H*indices_all[1])[mask]
# out_ind_1d = (indices_all[0]+H*indices_all[1])[torch.logical_not(mask)]
# ### tried using 0.5 p ratio of sampling inside vs outside, as smaller area already
# ### leads to more samples inside anyways

p = 0.5#(s_H*s_W)/(H*W)
select_ind = in_ind_1d[
torch.multinomial(
torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)]
select_ind = torch.concatenate(
[select_ind, out_ind_1d[torch.multinomial(
torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)]
],
dim=0).to(dtype=torch.int).view(s_H,s_W)
# p = 0.5#(s_H*s_W)/(H*W)
# select_ind = in_ind_1d[
# torch.multinomial(
# torch.ones_like(in_ind_1d)*(1-p),int((1-p)*(s_H*s_W)),replacement=False)]
# select_ind = torch.concatenate(
# [select_ind, out_ind_1d[torch.multinomial(
# torch.ones_like(out_ind_1d)*(p),int((p)*(s_H*s_W)),replacement=False)]
# ],
# dim=0).to(dtype=torch.int).view(s_H,s_W)

### first attempt at sampling: this produces a variable number of rays,
### so 4D tensor directions can't be sampled
Expand Down

0 comments on commit 96cb5b8

Please sign in to comment.