diff --git a/dataset.py b/dataset.py index 3b93666..ccbf52e 100644 --- a/dataset.py +++ b/dataset.py @@ -161,10 +161,10 @@ def next_batch(self,batchsize): for i in range(batchsize): for k in range(self.seq_length): - x_batch[i,k,0,:,:] = cv2.imread(root_path+"/SCMD_"+str(random_order[i])+"/"+str(k+1)+".png",cv2.IMREAD_GRAYSCALE) - y_batch[i,0,0,:,:] = cv2.imread(root_path+"/SCMD_"+str(random_order[i])+"/"+str(self.seq_length+1)+".png",cv2.IMREAD_GRAYSCALE) + x_batch[i,k,0,:,:] = cv2.imread(root_path+"/"+str(random_order[i])+"/"+str(k+1)+".png",cv2.IMREAD_GRAYSCALE) + y_batch[i,0,0,:,:] = cv2.imread(root_path+"/"+str(random_order[i])+"/"+str(self.seq_length+1)+".png",cv2.IMREAD_GRAYSCALE) - x_batch,y_batch = x_batch*10.0,y_batch*10.0 + x_batch,y_batch = x_batch,y_batch if self.baseline: x_batch = x_batch[:,:,:,100,100]