Fix minor issues for final release
washingtonsk8 committed Apr 1, 2022
1 parent 4e1c22f commit 18f601f
Showing 3 changed files with 69 additions and 80 deletions.
28 changes: 5 additions & 23 deletions hubconf.py
@@ -1,15 +1,10 @@
dependencies = ['torch', 'torchvision', 'torchvideo', 'numpy']

from semantic_encoding.utils import init_weights
from text_driven_video_acceleration import JointModel
import torch
import numpy as np

def TextDrivenVideoAcceleration(pretrained=False, progress=False, sent_emb_size=512, hidden_feat_emb_size=512, final_feat_emb_size=128, sent_att_size=1024, word_att_size=1024, use_visual_shortcut=True, use_sentence_level_attention=True, use_word_level_attention=True, word_embeddings=None, fine_tune_word_embeddings=False, fine_tune_resnet=False, learn_first_hidden_vector=False, action_size=3, ):
if not word_embeddings:
word_embeddings = np.random.random((400002, 300)).astype(np.float32)

model = JointModel(vocab_size=400002,
def TextDrivenVideoAcceleration(pretrained=False, progress=False, sent_emb_size=512, hidden_feat_emb_size=512, final_feat_emb_size=128, sent_att_size=1024, word_att_size=1024, use_visual_shortcut=True, use_sentence_level_attention=True, use_word_level_attention=True, learn_first_hidden_vector=False, action_size=3):

model = JointModel(vocab_size=400002, # Number of words in the GloVe vocabulary
doc_emb_size=512, # R(2+1)D embedding size
sent_emb_size=sent_emb_size,
word_emb_size=300, # GloVe embeddings size
@@ -23,20 +18,7 @@ def TextDrivenVideoAcceleration(pretrained=False, progress=False, sent_emb_size=
use_sentence_level_attention=use_sentence_level_attention,
use_word_level_attention=use_word_level_attention,
learn_first_hidden_vector=learn_first_hidden_vector,
action_size=action_size)

# Init word embeddings layer with pretrained embeddings
model.vdan_plus.text_embedder.doc_embedder.sent_embedder.init_pretrained_embeddings(torch.from_numpy(word_embeddings))
model.vdan_plus.text_embedder.doc_embedder.sent_embedder.allow_word_embeddings_finetunening(fine_tune_word_embeddings) # Make it available to finetune the word embeddings
model.vdan_plus.vid_embedder.fine_tune(fine_tune_resnet) # Freeze/Unfreeze R(2+1)D layers. We didn't use it in our paper. But, feel free to try ;)

model.vdan_plus.apply(init_weights) # Apply function "init_weights" to all FC layers of our model.

if pretrained:
vdan_plus_state_dict = torch.hub.load_state_dict_from_url('https://github.com/verlab/TextDrivenVideoAcceleration_TPAMI_2022/releases/download/pre_release/vdan+_pretrained_model.pth', progress=progress)
agent_state_dict = torch.hub.load_state_dict_from_url('https://github.com/verlab/TextDrivenVideoAcceleration_TPAMI_2022/releases/download/pre_release/youcookii_saffa_model.pth', progress=progress)

model.vdan_plus.load_state_dict(vdan_plus_state_dict['model_state_dict'])
model.policy.load_state_dict(agent_state_dict['policy_state_dict'])
action_size=action_size,
pretrained=pretrained)

return model
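
A minimal usage sketch for the updated entrypoint (not part of this commit): the hub call assumes the repository's default branch and forwards keyword arguments to TextDrivenVideoAcceleration(); note that word_embeddings, fine_tune_word_embeddings, and fine_tune_resnet are no longer accepted, since pretrained weights are now loaded inside JointModel.

import torch

# torch.hub forwards pretrained/progress to the TextDrivenVideoAcceleration entrypoint above.
model = torch.hub.load('verlab/TextDrivenVideoAcceleration_TPAMI_2022',
                       'TextDrivenVideoAcceleration',
                       pretrained=True, progress=True)
model.eval()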
9 changes: 6 additions & 3 deletions semantic_encoding/utils.py
@@ -97,11 +97,14 @@ def save_checkpoint(epoch, model, optimizer, word_map, datetimestamp, model_para
print('\t[{}] Done!\n'.format(datetime.now().strftime('%Y-%m-%d %H:%M:%S')))


def load_checkpoint(filename, is_vdan=False):
def load_checkpoint(filename_or_url, is_vdan=False, load_from_url=False, progress=False):
"""
Load model checkpoint.
"""
checkpoint = torch.load(filename, map_location=device)
if load_from_url:
checkpoint = torch.hub.load_state_dict_from_url(filename_or_url, progress=progress)
else:
checkpoint = torch.load(filename_or_url, map_location=device)

epoch = checkpoint['epoch']
word_map = checkpoint['word_map']
@@ -179,7 +182,7 @@ def computeMRR(X, Y):

return MRR

def extract_vdan_plus_feats(model, train_params, model_params, word_map, video_filename, document_filename, batch_size, max_frames, use_vid_transformer=False, tqdm_leave=True):
def extract_vdan_plus_feats(model, train_params, model_params, word_map, video_filename, document_filename, batch_size, max_frames, tqdm_leave=True):
# Load inputs
document = np.loadtxt(document_filename, delimiter='\n', dtype=str, encoding='utf-8')

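
A minimal sketch of the updated loader (not part of this commit), assuming the six return values keep the order used when this function is called in text_driven_video_acceleration.py (the third value is presumably the optimizer state):

from semantic_encoding.utils import load_checkpoint

ckpt_url = 'https://github.com/verlab/TextDrivenVideoAcceleration_TPAMI_2022/releases/download/pre_release/vdan+_pretrained_model.pth'
epoch, vdan_plus_model, _, word_map, model_params, train_params = load_checkpoint(
    ckpt_url, load_from_url=True, progress=True)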
112 changes: 58 additions & 54 deletions text_driven_video_acceleration.py
@@ -1,37 +1,29 @@
import os
import cv2
import torch
import torch.nn as nn
import json
import cv2
import os
import numpy as np
import torchvideo.transforms as VT
from torchvision.transforms import Compose
from tqdm import tqdm
from semantic_encoding.models import VDAN_PLUS
from semantic_encoding.utils import convert_sentences_to_word_idxs
from semantic_encoding.utils import load_checkpoint, extract_vdan_plus_feats
from rl_fast_forward.REINFORCE.policy import Policy
from rl_fast_forward.REINFORCE.critic import Critic

KINECTS400_MEAN = [0.43216, 0.394666, 0.37645]
KINECTS400_STD = [0.22803, 0.22145, 0.216989]

MIN_SKIP = 1
MAX_SKIP = 25
MAX_ACC = 5
MIN_ACC = 1
MAX_FRAMES = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class JointModel(nn.Module):
def __init__(
self, vocab_size, doc_emb_size, sent_emb_size, word_emb_size, sent_rnn_layers, word_rnn_layers, hidden_feat_emb_size, final_feat_emb_size,
def __init__(self, vocab_size, doc_emb_size, sent_emb_size, word_emb_size, sent_rnn_layers, word_rnn_layers, hidden_feat_emb_size, final_feat_emb_size,
sent_att_size, word_att_size, use_visual_shortcut=False, use_sentence_level_attention=False, use_word_level_attention=False,
learn_first_hidden_vector=True, action_size=3):
learn_first_hidden_vector=True, action_size=3, pretrained=False, progress=False):
super(JointModel, self).__init__()

self.vdan_plus = VDAN_PLUS(vocab_size=vocab_size,
doc_emb_size=doc_emb_size, # ResNet-50 embedding size
doc_emb_size=doc_emb_size, # R(2+1)D embedding size
sent_emb_size=sent_emb_size,
word_emb_size=word_emb_size, # GloVe embeddings size
sent_rnn_layers=sent_rnn_layers,
@@ -51,69 +43,70 @@ def __init__(

self.policy = Policy(state_size=self.state_size, action_size=action_size)
self.critic = Critic(state_size=self.state_size)

if pretrained:
_, vdan_plus_model, _, vdan_plus_word_map, vdan_plus_model_params, vdan_plus_train_params = load_checkpoint('https://github.com/verlab/TextDrivenVideoAcceleration_TPAMI_2022/releases/download/pre_release/vdan+_pretrained_model.pth', load_from_url=True, progress=progress)

self.vdan_plus_data = {
'model_name': 'vdan+_pretrained_model',
'semantic_encoder_model': vdan_plus_model,
'word_map': vdan_plus_word_map,
'train_params': vdan_plus_train_params,
'model_params': vdan_plus_model_params,
'input_frames_length': 32
}

self.vdan_plus = vdan_plus_model

agent_state_dict = torch.hub.load_state_dict_from_url('https://github.com/verlab/TextDrivenVideoAcceleration_TPAMI_2022/releases/download/pre_release/youcookii_saffa_model.pth', progress=progress)
self.policy.load_state_dict(agent_state_dict['policy_state_dict'])

def fast_forward_video(self, video_filename, document, desired_speedup, output_video_filename=None, max_words=20):
word_map = json.load(open(f'{os.path.dirname(os.path.abspath(__file__))}/semantic_encoding/resources/glove6B_word_map.json'))
def fast_forward_video(self, video_filename, document, desired_speedup, output_video_filename=None, vdan_plus_batch_size=16, semantic_embeddings=None):

if semantic_embeddings is None:
semantic_embeddings = self.get_semantic_embeddings(video_filename, document, vdan_plus_batch_size).unsqueeze(0)

semantic_embeddings = semantic_embeddings.to(device)

vid_transformer = Compose([
VT.NDArrayToPILVideo(),
VT.ResizeVideo(112),
VT.PILVideoToTensor(),
VT.NormalizeVideo(mean=KINECTS400_MEAN, std=KINECTS400_STD)
])

converted_sentences, words_per_sentence = convert_sentences_to_word_idxs(document, max_words, word_map)
sentences_per_document = np.array([converted_sentences.shape[0]])

transformed_document = torch.from_numpy(converted_sentences).unsqueeze(0).to(device) # (batch_size, sentence_limit, word_limit)
sentences_per_document = torch.from_numpy(sentences_per_document).to(device) # (batch_size)
words_per_sentence = torch.from_numpy(words_per_sentence).unsqueeze(0).to(device) # (batch_size, sentence_limit)

video = cv2.VideoCapture(video_filename)
if output_video_filename:
video = cv2.VideoCapture(video_filename)
fourcc = cv2.VideoWriter_fourcc('M', 'P', 'E', 'G')
fps = video.get(5)
frame_width = int(video.get(3))
frame_height = int(video.get(4))
output_video = cv2.VideoWriter(output_video_filename, fourcc, fps, (frame_width, frame_height))

num_frames = int(video.get(7))
acceleration = 1
skip = 1
frame_idx = 0
selected_frames = []

selected_frames = []
num_frames = semantic_embeddings.shape[0]

self.Im = torch.eye(self.q).to(device)
self.NRPE = self.get_NRPE(num_frames).to(device)

curr_frames = [None for _ in range(MAX_FRAMES)]
skips = [skip]
pbar = tqdm(total=num_frames)
while frame_idx < num_frames:
video.set(1, frame_idx)
for i in range(MAX_FRAMES):
ret, frame = video.read()
if output_video_filename:
i = 0
while i < skip and frame_idx < num_frames:
ret, frame = video.read()
i += 1

if not ret:
curr_frames[i] = np.zeros((int(video.get(4)), int(video.get(3)), 3), dtype=np.uint8)
continue

frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
curr_frames[i] = frame
print('Error reading frame: {}'.format(frame_idx))
break

transformed_frames = vid_transformer(curr_frames).unsqueeze(0).to(device)

if output_video_filename:
font = cv2.FONT_HERSHEY_SIMPLEX
cv2.putText(curr_frames[0], '{}x'.format(skip), (50, 50), font, 1, (255, 255, 255), 2, cv2.LINE_AA)
output_video.write(curr_frames[0])

with torch.no_grad():
vid_embedding, text_embedding, _, _, _ = self.vdan_plus(transformed_frames, transformed_document, sentences_per_document, words_per_sentence)
cv2.putText(frame, '{}x'.format(skip), (50, 50), font, 1, (255, 255, 255), 2, cv2.LINE_AA)
output_video.write(frame)

SA_vector = self.Im[int(np.round((np.mean(skips) - desired_speedup) + MAX_SKIP))]
observation = torch.cat([vid_embedding, text_embedding, SA_vector.unsqueeze(0), self.NRPE[frame_idx].unsqueeze(0)], axis=1).unsqueeze(0)
action_probs = self.policy(observation)
observation = torch.cat((semantic_embeddings[frame_idx],
self.NRPE[frame_idx],
self.Im[int(np.round((np.mean(skips) - desired_speedup) + MAX_SKIP))])).unsqueeze(0)

action_probs = self.policy(observation.unsqueeze(0))

action = torch.argmax(action_probs.squeeze(0)).item()

@@ -154,3 +147,14 @@ def get_NRPE(self, F):
NRPE[-(t+1), odd_idxs] = np.cos(wks*t)

return torch.from_numpy(NRPE)

def get_semantic_embeddings(self, video_filename, document, vdan_plus_batch_size=16):

temp_doc_filename = f'{os.path.basename(os.path.abspath(video_filename)).split(".")[0]}.txt'
np.savetxt(temp_doc_filename, document, fmt='%s')

vid_embeddings, doc_embeddings, _, _, _ = extract_vdan_plus_feats(self.vdan_plus_data['semantic_encoder_model'], self.vdan_plus_data['train_params'], self.vdan_plus_data['model_params'], self.vdan_plus_data['word_map'], video_filename, temp_doc_filename, vdan_plus_batch_size, self.vdan_plus_data['input_frames_length'])

os.remove(temp_doc_filename)

return torch.from_numpy(np.concatenate([doc_embeddings, vid_embeddings], axis=1))
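
A minimal usage sketch of the refactored interface (not part of this commit), assuming a pretrained JointModel instance named model; the file names, document text, and speedup value are illustrative:

# One sentence per entry; get_semantic_embeddings writes these to a temporary text file.
document = ['Pour the batter into the pan.', 'Bake for thirty minutes.']

# Option A: let fast_forward_video compute the semantic embeddings internally.
model.fast_forward_video('video.mp4', document, desired_speedup=12,
                         output_video_filename='video_fast.mp4')

# Option B: precompute the embeddings once and reuse them across runs
# (mirrors the internal path, including the leading unsqueeze).
embeddings = model.get_semantic_embeddings('video.mp4', document, vdan_plus_batch_size=16).unsqueeze(0)
model.fast_forward_video('video.mp4', document, desired_speedup=12,
                         semantic_embeddings=embeddings)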
