Commit Graph

5 Commits (a2f539f0552a9958a1960c8f5079a8b3782eb803)

Author SHA1 Message Date
Ross Wightman ab8cb070fc Add xavier_uniform init of MNVC hybrid attention modules. Small improvement in training stability. 2024-07-26 17:03:40 -07:00
Ross Wightman 2180800646 MQA query_strides bugs fix #2237. No padding for avg_pool2d if not 'same', use scale_factor for Upsample. 2024-07-19 14:26:54 -07:00
Ross Wightman 7fe96e7a92 More MobileNet-v4 fixes
* missed final norm after post pooling 1x1 PW head conv
* improve repr of model by flipping a few modules to None when not used, nn.Sequential for MultiQueryAttention query/key/value/output
* allow layer scaling to be enabled/disabled at model variant level, conv variants don't use it
2024-05-24 15:09:29 -07:00
Ross Wightman 70176a2dae torchscript typing fixes 2024-05-23 11:43:05 -07:00
Ross Wightman 2a1a6b1236 Adding missing attention2d.py 2024-05-23 11:06:32 -07:00