A **Swish-gated linear unit (SwiGLU)** takes two linear projections of the input, A and B: A drives the Swish gate and B carries the value stream. The Swish gate is computed by multiplying the projection A element-wise with its sigmoid (logistic) value $sig(A)$. This gate then modulates the value stream through element-wise multiplication (Hadamard product) with the second projection B, producing the output.
$\big(linA(x) \odot sig(linA(x))\big) \odot linB(x)$
Where $\odot$ is the Hadamard product (`*` operator applied to tensors in PyTorch).
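For concreteness, a tiny example of the Hadamard product on two tensors (the values here are arbitrary):

```python
import torch

a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])
print(a * b)  # element-wise product: tensor([ 4., 10., 18.])
```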
```python
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linA = nn.Linear(dim_in, dim_out)  # gate projection
        self.linB = nn.Linear(dim_in, dim_out)  # value projection

    def forward(self, x):
        a = self.linA(x)
        gate = a * torch.sigmoid(a)  # Swish(a) = a * sig(a)
        return gate * self.linB(x)   # Hadamard product with the value stream
```
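As a sanity check, the gate $a \odot sig(a)$ coincides with PyTorch's built-in `nn.SiLU` activation (Swish with $\beta = 1$). A minimal usage sketch, with the module definition repeated so the snippet runs on its own and arbitrary dimensions chosen for illustration:

```python
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linA = nn.Linear(dim_in, dim_out)  # gate projection
        self.linB = nn.Linear(dim_in, dim_out)  # value projection

    def forward(self, x):
        a = self.linA(x)
        return a * torch.sigmoid(a) * self.linB(x)

torch.manual_seed(0)
layer = SwiGLU(dim_in=8, dim_out=16)
x = torch.randn(4, 8)                  # batch of 4 input vectors
out = layer(x)
print(out.shape)                       # torch.Size([4, 16])

# a * sigmoid(a) is exactly SiLU(a), i.e. Swish with beta = 1
expected = nn.SiLU()(layer.linA(x)) * layer.linB(x)
print(torch.allclose(out, expected))   # True
```

Note that the output dimension follows `dim_out`, not the input: both projections map `dim_in` to `dim_out`, and the Hadamard product preserves that shape.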