Adam J. Stewart
f5c4d5cbb7
Add missing imports
2025-01-11 15:13:16 +01:00
Adam J. Stewart
19aaea3c8f
Fix nn.Module type hints
2025-01-11 15:09:21 +01:00
Louis Lac
2d5277e858
Merge branch 'main' into fix-mqa-v2
2025-01-02 00:11:22 +01:00
Louis Lac
2d734d9058
Fixed unfused attn2d scale
2025-01-01 12:34:07 -08:00
Louis Lac
6171e756d3
Fix MQA V2 scale and out shape
2025-01-01 15:37:28 +01:00
Ross Wightman
ab8cb070fc
Add xavier_uniform init of MNVC hybrid attention modules. Small improvement in training stability.
2024-07-26 17:03:40 -07:00
Ross Wightman
2180800646
MQA query_strides bugs fix #2237. No padding for avg_pool2d if not 'same', use scale_factor for Upsample.
2024-07-19 14:26:54 -07:00
Ross Wightman
7fe96e7a92
More MobileNet-v4 fixes
* missed final norm after post pooling 1x1 PW head conv
* improve repr of model by flipping a few modules to None when not used, nn.Sequential for MultiQueryAttention query/key/value/output
* allow layer scaling to be enabled/disabled at model variant level, conv variants don't use it
2024-05-24 15:09:29 -07:00
Ross Wightman
70176a2dae
torchscript typing fixes
2024-05-23 11:43:05 -07:00
Ross Wightman
2a1a6b1236
Adding missing attention2d.py
2024-05-23 11:06:32 -07:00