class Memory(nn.Module):
    """Memory module that stores `memory_size` prototype vectors of dimension
    `key_dim` and lets spatial query features read from / write to them via
    softmax attention (MNAD-style memory).

    Conventions used throughout:
      * `query` enters `forward` as (batch, dims, h, w) and is handled
        internally as (batch, h, w, dims);
      * `keys` / `mem` is the memory matrix of shape (m, d).
    """

    def __init__(self, memory_size, feature_dim, key_dim, temp_update, temp_gather):
        super(Memory, self).__init__()
        self.memory_size = memory_size   # m: number of memory slots
        self.feature_dim = feature_dim   # channel dim of incoming features
        self.key_dim = key_dim           # d: dimension of each memory slot
        # NOTE(review): the two temperatures are stored but never read by any
        # method below — presumably kept for interface/checkpoint compatibility.
        self.temp_update = temp_update
        self.temp_gather = temp_gather

    def hard_neg_mem(self, mem, i):
        """Return, for each row of `mem`, its most similar key other than slot i.

        NOTE(review): depends on `self.keys_var`, which is never assigned in
        this class — it must be set externally before calling; verify callers.
        """
        similarity = torch.matmul(mem, torch.t(self.keys_var))
        similarity[:, i] = -1  # mask out slot i so it cannot be its own negative
        _, max_idx = torch.topk(similarity, 1, dim=1)
        return self.keys_var[max_idx]

    def random_pick_memory(self, mem, max_indices):
        """For each memory slot, pick one random query index assigned to it.

        Slots with no assigned query get the sentinel -1, so the returned
        tensor mixes real indices and sentinels.
        """
        m, d = mem.size()
        output = []
        for i in range(m):
            flattened_indices = (max_indices == i).nonzero()
            a, _ = flattened_indices.size()
            if a != 0:
                number = np.random.choice(a, 1)
                output.append(flattened_indices[number, 0])
            else:
                output.append(-1)
        return torch.tensor(output)

    def get_update_query(self, mem, max_indices, update_indices, score, query, train):
        """Build the per-slot update vector: a score-weighted sum of the
        queries whose top-1 match is that slot.

        Args:
            mem: memory matrix (m, d).
            max_indices: (n, 1) top-1 slot index per query.
            update_indices: unused here (kept for interface compatibility).
            score: (n, m) query-to-slot score matrix.
            query: (n, d) flattened queries.
            train: unused — kept because callers pass it (see fix note).

        Returns:
            (m, d) tensor of update vectors (zero rows for unmatched slots).
        """
        m, d = mem.size()
        # FIX: the original duplicated this exact loop in separate train/eval
        # branches; collapsed into one path (behavior unchanged).
        # FIX: allocate on mem's device rather than hard-coded .cuda(), so the
        # module also runs on CPU; identical result on GPU inputs.
        query_update = torch.zeros((m, d), device=mem.device)
        for i in range(m):
            idx = torch.nonzero(max_indices.squeeze(1) == i)
            a, _ = idx.size()
            if a != 0:
                # weight each assigned query by its score, normalized by the
                # maximum score any query gives this slot
                query_update[i] = torch.sum(
                    (score[idx, i] / torch.max(score[:, i])) * query[idx].squeeze(1),
                    dim=0,
                )
            else:
                query_update[i] = 0
        return query_update

    def get_score(self, mem, query):
        """Compute bidirectional softmax attention between queries and memory.

        Returns:
            score_query:  (b*h*w, m), softmax over the query axis (dim=0) —
                          how strongly each slot selects each query.
            score_memory: (b*h*w, m), softmax over the slot axis (dim=1) —
                          how strongly each query selects each slot.
        """
        bs, h, w, d = query.size()
        m, d = mem.size()
        score = torch.matmul(query, torch.t(mem))  # (b, h, w, m)
        score = score.view(bs * h * w, m)
        score_query = F.softmax(score, dim=0)
        score_memory = F.softmax(score, dim=1)
        return score_query, score_memory

    def forward(self, query, keys, train=True):
        """Read from (and, in training, update) the memory.

        Train returns:  (updated_query, updated_memory, score_query,
                         score_memory, separateness_loss, compactness_loss)
        Eval returns:   (updated_query, keys, score_query, score_memory,
                         query_reshape, top1_keys, top1_indices,
                         compactness_loss)
        """
        batch_size, dims, h, w = query.size()
        query = F.normalize(query, dim=1)
        query = query.permute(0, 2, 3, 1)  # (b, h, w, d)
        if train:
            separateness_loss, compactness_loss = self.gather_loss(query, keys, train)
            updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
            updated_memory = self.update(query, keys, train)
            return (updated_query, updated_memory, softmax_score_query,
                    softmax_score_memory, separateness_loss, compactness_loss)
        else:
            compactness_loss, query_re, top1_keys, keys_ind = self.gather_loss(query, keys, train)
            updated_query, softmax_score_query, softmax_score_memory = self.read(query, keys)
            updated_memory = keys  # memory is frozen at test time
            return (updated_query, updated_memory, softmax_score_query,
                    softmax_score_memory, query_re, top1_keys, keys_ind,
                    compactness_loss)

    def update(self, query, keys, train):
        """Return a detached, row-normalized memory updated with the current
        batch of queries."""
        batch_size, h, w, dims = query.size()
        softmax_score_query, softmax_score_memory = self.get_score(keys, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)
        _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
        _, updating_indices = torch.topk(softmax_score_query, 1, dim=0)
        # FIX: the original's train/eval branches were byte-identical; merged
        # into one path (behavior unchanged).
        query_update = self.get_update_query(keys, gathering_indices, updating_indices,
                                             softmax_score_query, query_reshape, train)
        updated_memory = F.normalize(query_update + keys, dim=1)
        return updated_memory.detach()

    def pointwise_gather_loss(self, query_reshape, keys, gathering_indices, train):
        """Element-wise (unreduced) MSE between each query and its top-1 slot."""
        n, dims = query_reshape.size()
        loss_mse = torch.nn.MSELoss(reduction='none')
        pointwise_loss = loss_mse(query_reshape,
                                  keys[gathering_indices].squeeze(1).detach())
        return pointwise_loss

    def gather_loss(self, query, keys, train):
        """Compute the memory losses.

        Train: returns (separateness/triplet loss, compactness/MSE loss) using
        the top-2 slots per query (nearest = positive, 2nd = negative).
        Eval: returns (compactness loss, flattened queries, top-1 keys,
        top-1 slot indices).
        """
        batch_size, h, w, dims = query.size()
        if train:
            loss = torch.nn.TripletMarginLoss(margin=1.0)
            loss_mse = torch.nn.MSELoss()
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 2, dim=1)
            pos = keys[gathering_indices[:, 0]]  # nearest slot (positive)
            neg = keys[gathering_indices[:, 1]]  # second nearest (negative)
            top1_loss = loss_mse(query_reshape, pos.detach())
            gathering_loss = loss(query_reshape, pos.detach(), neg.detach())
            return gathering_loss, top1_loss
        else:
            loss_mse = torch.nn.MSELoss()
            softmax_score_query, softmax_score_memory = self.get_score(keys, query)
            query_reshape = query.contiguous().view(batch_size * h * w, dims)
            _, gathering_indices = torch.topk(softmax_score_memory, 1, dim=1)
            gathering_loss = loss_mse(query_reshape,
                                      keys[gathering_indices].squeeze(1).detach())
            return (gathering_loss, query_reshape,
                    keys[gathering_indices].squeeze(1).detach(),
                    gathering_indices[:, 0])

    def read(self, query, updated_memory):
        """Concatenate each query with its attention-weighted memory readout.

        Returns a (b, 2*dims, h, w) tensor plus both score matrices.
        """
        batch_size, h, w, dims = query.size()
        softmax_score_query, softmax_score_memory = self.get_score(updated_memory, query)
        query_reshape = query.contiguous().view(batch_size * h * w, dims)
        concat_memory = torch.matmul(softmax_score_memory.detach(), updated_memory)  # (bhw, d)
        updated_query = torch.cat((query_reshape, concat_memory), dim=1)  # (bhw, 2d)
        updated_query = updated_query.view(batch_size, h, w, 2 * dims)
        updated_query = updated_query.permute(0, 3, 1, 2)  # (b, 2d, h, w)
        return updated_query, softmax_score_query, softmax_score_memory