Problem formulation (Easeformer Core)
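$$\mathrm{head}_i = \mathrm{Softmax}\!\left(\frac{(h_t)^{\top} W_i^{Q}\,\big((h_t)^{\top} W_i^{K}\big)^{\top}}{\sqrt{d}}\right)(h_t)^{\top} W_i^{V}$$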
# decoder input
# Assemble the decoder input by concatenating along dim 1 (presumably time):
#   [ batch_y[:, :label_len, :]  |  forecast window of pred_len steps ]
# In the forecast window, channel 0 is copied from
# batch_y[:, label_len:label_len + pred_len, :1] and every other channel is a
# constant fill: 0 when args.padding == 0, 1 when args.padding == 1.
# The fill must have batch_y.shape[-1] - 1 channels in BOTH modes so that,
# together with channel 0, the window has exactly as many channels as the
# batch_y prefix it is concatenated to; otherwise torch.cat on dim 1 fails.
# NOTE(review): channel 0 of the forecast window is taken from batch_y itself.
# This is only sound if that channel is known in advance (a covariate) rather
# than the forecast target — confirm, otherwise future values leak into the
# decoder.
if self.args.padding not in (0, 1):
    raise ValueError(
        f"unsupported padding mode {self.args.padding!r}; expected 0 or 1"
    )
# Allocate the fill on batch_y's device so the concatenations below also work
# when batch_y is already on an accelerator.
dec_fill = torch.full(
    (batch_y.shape[0], self.args.pred_len, batch_y.shape[-1] - 1),
    float(self.args.padding),
    dtype=torch.float32,
    device=batch_y.device,
)
dec_inp = torch.cat(
    (
        batch_y[:, self.args.label_len:self.args.label_len + self.args.pred_len, :1],
        dec_fill,
    ),
    dim=2,
)
dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
# encoder - decoder
# Run the model exactly once. When args.use_amp is set the call runs under
# torch.cuda.amp.autocast(); otherwise a no-op context is used, so no autocast
# region is entered (same as calling the model directly).
# When args.output_attention is set, only element [0] of the model's return
# value is kept (presumably predictions, with attention maps discarded —
# TODO confirm against the model's forward()).
import contextlib

amp_context = torch.cuda.amp.autocast() if self.args.use_amp else contextlib.nullcontext()
with amp_context:
    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
    if self.args.output_attention:
        outputs = outputs[0]