>>106758784
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUTorch(nn.Module):
    """
    SwiGLU MLP: y = W3( SiLU(W1 x) ⊙ (W2 x) )
    - Supports packed weights via a single Linear projecting to 2*hidden_features.
    - For compatibility with callers that pass extra kwargs (e.g., HW=...), forward accepts **kwargs.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True, _pack_weights=True):
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features or in_features
        self.out_features = out_features or in_features
        self._pack_weights = _pack_weights
        if _pack_weights:
            # Packed: one projection produces both the gate and value halves at once.
            self.w12 = nn.Linear(in_features, 2 * self.hidden_features, bias=bias)
            self.w1 = None
            self.w2 = None
        else:
            # Unpacked: separate projections for the gate (w1) and value (w2) branches.
            self.w12 = None
            self.w1 = nn.Linear(in_features, self.hidden_features, bias=bias)
            self.w2 = nn.Linear(in_features, self.hidden_features, bias=bias)
        self.w3 = nn.Linear(self.hidden_features, self.out_features, bias=bias)

    def forward(self, x, *args, **kwargs):
        if self.w12 is not None:
            x1, x2 = self.w12(x).chunk(2, dim=-1)
        else:
            x1 = self.w1(x)
            x2 = self.w2(x)
        return self.w3(F.silu(x1) * x2)
It's how the layers and parameters are glued together, and it controls how the data flows.
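If it helps, here's a minimal usage sketch (the tensor shapes, hidden size, and the HW kwarg are just my own picks for illustration, not anything from the original code): both the packed and unpacked configurations take an input of shape (batch, tokens, in_features) and return the same shape, and extra kwargs like HW are simply ignored by forward.

import torch

x = torch.randn(2, 16, 64)                       # (batch, tokens, in_features)

mlp_packed = SwiGLUTorch(64, hidden_features=128, _pack_weights=True)
mlp_split = SwiGLUTorch(64, hidden_features=128, _pack_weights=False)

y_packed = mlp_packed(x, HW=(4, 4))              # extra kwargs are accepted and ignored
y_split = mlp_split(x)

print(y_packed.shape, y_split.shape)             # torch.Size([2, 16, 64]) for both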