transformer

attention

self attention

transformer encoder

cross attention

transformer decoder

causal attention

CLIP

class Clip(nn.Module):
    """Contrastive motion/music alignment model with a symmetric loss.

    Both modalities are encoded, linearly projected, and L2-normalised along
    the last dimension. For every batch element a similarity matrix between
    motion positions and music positions is built, and a cross-entropy loss
    is computed in both directions with the matching position as the target.

    Args:
        motion_dim: Passed to ``MotionEncoder`` as ``input_channels``.
        music_dim: Passed to ``MusicEncoder`` as ``input_channels``.
        feature_dim: Passed to both encoders as ``feature_dim`` and used as
            the input and output width of both projection layers.
    """

    def __init__(self, motion_dim=75, music_dim=438, feature_dim=256):
        super().__init__()

        self.motion_encoder = MotionEncoder(
            input_channels=motion_dim, feature_dim=feature_dim
        )
        self.music_encoder = MusicEncoder(
            input_channels=music_dim, feature_dim=feature_dim
        )

        self.motion_project = nn.Linear(feature_dim, feature_dim)
        self.music_project = nn.Linear(feature_dim, feature_dim)

        # Learnable multiplicative scale applied to the similarity logits.
        # It is a raw (unconstrained) scalar initialised to 1.0.
        # NOTE(review): nothing keeps it positive during training - confirm
        # this is intended.
        self.temperature = nn.Parameter(torch.tensor(1.0))

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, motion: Tensor, music: Tensor):
        """Embed both modalities and compute the symmetric contrastive loss.

        Args:
            motion: 3-D tensor; its first two dimensions are used as batch
                size ``b`` and sequence length ``s``.
            music: Tensor whose dimension 1 must equal ``s``.
                Assumes both encoders keep the ``(b, s)`` leading layout so
                that the logits are ``(b, s, s)`` - TODO confirm against the
                encoder definitions.

        Returns:
            A pair ``((motion_features, music_features),
            (loss, loss_motion, loss_music))`` where the features are the
            projected, L2-normalised embeddings and ``loss`` is the mean of
            the two directional losses.

        Raises:
            AssertionError: If the two inputs disagree on dimension 1.
        """
        assert motion.shape[1] == music.shape[1], (
            "motion and music must have the same sequence length, got "
            f"{motion.shape[1]} and {music.shape[1]}"
        )

        # Unpacking also enforces that ``motion`` is exactly 3-D.
        b, s, _ = motion.shape

        motion_features = self.motion_encoder(motion)
        music_features = self.music_encoder(music)

        motion_features = F.normalize(
            self.motion_project(motion_features), p=2, dim=-1
        )
        music_features = F.normalize(
            self.music_project(music_features), p=2, dim=-1
        )

        # Batched matrix product; ``.mT`` transposes the last two dimensions,
        # so logits[n, i, j] = <motion_features[n, i], music_features[n, j]>.
        logits = torch.bmm(motion_features, music_features.mT) * self.temperature

        # Position i of one modality is paired with position i of the other.
        # Build the targets directly on the input's device to avoid a
        # host-to-device copy on every forward pass.
        labels = torch.arange(s, device=motion.device).repeat(b, 1)

        # For 3-D input nn.CrossEntropyLoss uses dimension 1 as the class
        # axis: the first call softmaxes over the motion index for each music
        # position, the second (on the transposed logits) over the music
        # index for each motion position.
        loss_motion = self.criterion(logits, labels)
        loss_music = self.criterion(logits.mT, labels)

        loss = (loss_motion + loss_music) / 2

        return (motion_features, music_features), (loss, loss_motion, loss_music)