Trajectory prediction model
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#######################################
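# The MultiSelfAttention module used below is not defined in this section, so the
# following is a minimal, hypothetical stand-in (an assumption, not the original
# implementation) that wraps nn.MultiheadAttention. It matches the call site:
# constructed as MultiSelfAttention(heads, d_model, dropout=...) and called with
# (query, key, value) tensors of shape [batch_size, seq_len, d_model].
class MultiSelfAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.0):
        super(MultiSelfAttention, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=heads,
                                          dropout=dropout, batch_first=True)

    def forward(self, query, key, value):
        # attended output has the same shape as the query
        out, _ = self.attn(query, key, value, need_weights=False)
        return out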
class TrajPreModel(nn.Module):
    """Self-attention model for trajectory (next-location) prediction."""

    def __init__(self, loc_size=528, loc_emb_size=128, hidden_size=32, head_num=1, dropout_p=0):
        super(TrajPreModel, self).__init__()
        self.loc_size = loc_size
        self.loc_emb_size = loc_emb_size
        self.hidden_size = hidden_size  # stored but not used by the attention/fc layers below
        self.heads = head_num
        self.dropout_p = dropout_p
        # embedding
        self.emb_loc = nn.Embedding(self.loc_size, self.loc_emb_size)
        self.weight = self.emb_loc.weight  # kept for optional weight sharing with the output layer
        # ------------- model ---------------
        self.attention = MultiSelfAttention(self.heads, self.loc_emb_size, dropout=self.dropout_p)
        self.fc = nn.Linear(self.loc_emb_size, self.loc_size)
        self.is_weight_sharing = False  # if True, tie the output projection to the embedding weights
        self.init_weights()
        self.dropout = nn.Dropout(p=dropout_p)
    def init_weights(self):
        # Note: the 'weight_ih' / 'weight_hh' filters target recurrent-layer parameters;
        # with the attention/linear layers above they typically match nothing, so in
        # practice mostly the biases are initialized here.
        ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
        hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
        b = (param.data for name, param in self.named_parameters() if 'bias' in name)
        for t in ih:
            nn.init.xavier_uniform_(t)
        for t in hh:
            nn.init.orthogonal_(t)
        for t in b:
            nn.init.constant_(t, 0)
    def forward(self, x):
        seq = x[1]  # location id sequence, [batch_size, seq_len]
        loc_emb = self.emb_loc(seq)
        output = self.dropout(loc_emb)
        # self-attention over the embedded sequence
        output = self.attention(output, output, output)
        output = self.dropout(output)
        if not self.is_weight_sharing:
            y = self.fc(output)
        else:
            # reuse the embedding matrix as the output projection (tied weights)
            y = F.linear(output, self.weight)
        score = F.log_softmax(y, dim=-1)
        return score.view(-1, self.loc_size)  # [batch_size * seq_len, loc_size]
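# Usage sketch (assumption: forward() expects a tuple/list whose second element is
# the location-id sequence, since it only reads x[1]; the first element is a
# placeholder here and is ignored by this model).
if __name__ == "__main__":
    model = TrajPreModel(loc_size=528, loc_emb_size=128, head_num=1, dropout_p=0.1)
    batch_size, seq_len = 4, 10
    locs = torch.randint(0, model.loc_size, (batch_size, seq_len))
    placeholder = torch.zeros(batch_size, seq_len, dtype=torch.long)  # e.g. time ids, unused here
    scores = model((placeholder, locs))
    print(scores.shape)  # torch.Size([40, 528]), i.e. [batch_size * seq_len, loc_size]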