SINR update
Vishu26 committed Oct 21, 2024
1 parent 6bc76c1 commit 1b6d50c
Showing 4 changed files with 22 additions and 6 deletions.
README.md: 2 additions & 1 deletion
@@ -10,7 +10,8 @@ pip install rshf
```python
from rshf.satmae import SatMAE
model = SatMAE.from_pretrained("MVRL/satmae-vitlarge-fmow-pretrain-800")
-print(model.forward_encoder(torch.randn(1, 3, 224, 224), mask_ratio=0.0)[0].shape)
+input = model.transform(torch.randint(0, 256, (224, 224, 3)).float().numpy(), 224).unsqueeze(0)
+print(model.forward_encoder(input, mask_ratio=0.0)[0].shape)
```

### TODO:
rshf/sinr/__init__.py: 1 addition & 1 deletion
@@ -1 +1 @@
-from .model import ResidualFCNet as SINR
+from .model import ResidualFCNet as SINR, config, preprocess_locs
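The widened export makes the SINR pieces importable directly from `rshf.sinr`. A minimal sketch of what that enables (the defaults come from the `config` added in model.py below; loading real weights would use `SINR.from_pretrained` with a Hub repo id this commit does not show):

```python
# Sketch only: exercises the expanded rshf.sinr exports from this commit.
from rshf.sinr import SINR, config, preprocess_locs

# config mirrors the hyperparameters hard-coded in rshf/sinr/model.py.
print(config)  # {'num_inputs': 4, 'num_classes': 47375, 'num_filts': 256}

# With the new keyword defaults, the architecture builds without arguments
# (randomly initialized here; pretrained weights would come from SINR.from_pretrained).
model = SINR()
```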
rshf/sinr/model.py: 18 additions & 3 deletions
@@ -1,7 +1,13 @@
import torch
import torch.nn as nn
+import math
from huggingface_hub import PyTorchModelHubMixin

+config = {}
+config['num_inputs'] = 4
+config['num_classes'] = 47375
+config['num_filts'] = 256
+
class ResLayer(nn.Module):
    def __init__(self, linear_size):
        super(ResLayer, self).__init__()
@@ -23,7 +29,7 @@ def forward(self, x):

class ResidualFCNet(nn.Module, PyTorchModelHubMixin):

-    def __init__(self, num_inputs, num_classes, num_filts, depth=4):
+    def __init__(self, num_inputs=config['num_inputs'], num_classes=config['num_classes'], num_filts=config['num_filts'], depth=4):
        super(ResidualFCNet, self).__init__()
        self.inc_bias = False
        self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
@@ -34,12 +40,21 @@ def __init__(self, num_inputs, num_classes, num_filts, depth=4):
            layers.append(ResLayer(num_filts))
        self.feats = torch.nn.Sequential(*layers)

-    def forward(self, x, class_of_interest=None, return_feats=False):
+    def forward(self, x, class_of_interest=None, return_feats=True):
        loc_emb = self.feats(x)
        if return_feats:
            return loc_emb
        if class_of_interest is None:
            class_pred = self.class_emb(loc_emb)
        else:
            class_pred = self.eval_single_class(loc_emb, class_of_interest)
-        return torch.sigmoid(class_pred)
+        return torch.sigmoid(class_pred)
+
+def preprocess_locs(locs):
+    locs[:, 0] /= 180.0
+    locs[:, 1] /= 90.0
+
+    feats = torch.cat((torch.sin(math.pi*locs), torch.cos(math.pi*locs)), dim=1)
+
+    return feats
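
Taken together, the model.py changes mean `preprocess_locs` maps (lon, lat) degrees to the 4-d sin/cos features the network expects, and a bare `forward` call now returns location embeddings (`return_feats=True` by default). A rough end-to-end sketch under those assumptions, with illustrative coordinates and a randomly initialized model:

```python
import torch
from rshf.sinr import SINR, preprocess_locs

# Illustrative (longitude, latitude) pairs in degrees.
locs = torch.tensor([[-90.2, 38.6],
                     [ 77.6, 12.9]])

feats = preprocess_locs(locs)  # scale to [-1, 1], then concat sin/cos -> shape (2, 4)
model = SINR()                 # defaults: num_inputs=4, num_classes=47375, num_filts=256

emb = model(feats)                        # location embeddings, shape (2, 256)
probs = model(feats, return_feats=False)  # sigmoid scores over 47375 classes, shape (2, 47375)
print(emb.shape, probs.shape)
```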

setup.cfg: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
[metadata]
name = rshf
-version = 0.0.14
+version = 0.0.15
author = Srikumar Sastry
author_email = s.sastry@wustl.edu
description = RS pretrained models in huggingface style
