mlx: bump dependency (#16935)

Update MLX to 548dd80.

Fix direct MLX tests to run on pinned MLX threads so test execution matches the runner's MLX thread-affinity model.
This commit is contained in:
Daniel Hiltgen 2026-06-29 09:39:11 -07:00 committed by GitHub
parent 32a97b7493
commit 7926b99e0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 56 deletions

View file

@ -1 +1 @@
51b2768da7e1897d3c4258f7ddbb47083d1eef01
548dd80e87454f6e4c1c7736ce09551d145c11d5

View file

@ -43,8 +43,12 @@ func TestFromValues(t *testing.T) {
}
// TestComparisonOpsAndBernoulli skips when MLX is unavailable and otherwise
// hands the real test body to withMLXThread, which is responsible for choosing
// the thread the body executes on.
func TestComparisonOpsAndBernoulli(t *testing.T) {
	skipIfNoMLX(t)

	// The assertions live in testComparisonOpsAndBernoulli; this wrapper only
	// decides where they run.
	body := func() { testComparisonOpsAndBernoulli(t) }
	withMLXThread(t, body)
}
func testComparisonOpsAndBernoulli(t *testing.T) {
a := FromValues([]float32{1, 2, 3}, 3)
b := FromValues([]float32{1, 1, 4}, 3)
eq := a.Equal(b).AsType(DTypeInt32)

View file

@ -1,21 +1,44 @@
package gemma4
import (
"runtime"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// useMLXTestThread locks the calling test goroutine to its OS thread and
// prepares MLX for the remainder of the test.
//
// The thread lock is taken before MLX is touched, and the cleanup is
// registered before the availability check so the thread is unlocked on every
// exit path, including the skip. When mlx.CheckInit fails the test is skipped.
// Otherwise the GPU is made the default device if one is available. At cleanup
// time, and only if initialization succeeded, MLX is swept, its cache is
// cleared and the GPU default is re-applied before the thread is unlocked.
func useMLXTestThread(t *testing.T) {
	t.Helper()

	// preferGPU selects the GPU as the default device when MLX reports one.
	preferGPU := func() {
		if mlx.GPUIsAvailable() {
			mlx.SetDefaultDeviceGPU()
		}
	}

	runtime.LockOSThread()

	// ready records whether CheckInit succeeded, so cleanup never calls into
	// MLX after a skipped test.
	var ready bool
	release := func() {
		if ready {
			mlx.Sweep()
			mlx.ClearCache()
			preferGPU()
		}
		runtime.UnlockOSThread()
	}
	t.Cleanup(release)

	err := mlx.CheckInit()
	if err != nil {
		t.Skipf("MLX not available: %v", err)
	}

	ready = true
	preferGPU()
}
// onesLike builds a bfloat16 array of the requested shape by adding the scalar
// 0.01 to an all-zero array, so every element holds the same small value.
func onesLike(shape ...int) *mlx.Array {
	zeros := mlx.Zeros(mlx.DTypeBFloat16, shape...)
	return mlx.AddScalar(zeros, 0.01)
}
func TestMoEForward(t *testing.T) {
skipIfNoMLX(t)
// Small config matching 26b architecture pattern.
cfg := &TextConfig{
func tinyMoEConfig() *TextConfig {
return &TextConfig{
HiddenSize: 16, // tiny for testing
NumAttentionHeads: 2,
NumKeyValueHeads: 1,
@ -31,73 +54,91 @@ func TestMoEForward(t *testing.T) {
SlidingScale: 1.0,
FullScale: 1.0,
}
}
B, L := int32(1), int32(3)
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
// Test Router.Forward.
router := &Router{
// newRouter returns a Router whose weights are constant-filled test tensors
// sized from cfg: Proj wraps a [NumExperts, HiddenSize] weight via
// linearFromWeight, and Scale has length HiddenSize.
func newRouter(cfg *TextConfig) *Router {
	experts := int(cfg.NumExperts)
	hidden := int(cfg.HiddenSize)

	r := new(Router)
	r.Proj = linearFromWeight(onesLike(experts, hidden))
	r.Scale = onesLike(hidden)
	return r
}
t.Run("Router", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
sDims := scores.Dims()
iDims := inds.Dims()
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
}
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
}
})
// Test MoEBlock.Forward.
moe := &MoEBlock{
// newMoEBlock returns a MoEBlock populated with constant-filled test tensors
// sized from cfg. Gate and up weights are
// [NumExperts, HiddenSize, ExpertIntermediateSize], the down weight is
// [NumExperts, ExpertIntermediateSize, HiddenSize], and the per-expert scale
// has length NumExperts.
func newMoEBlock(cfg *TextConfig) *MoEBlock {
	experts := int(cfg.NumExperts)
	hidden := int(cfg.HiddenSize)
	inner := int(cfg.ExpertIntermediateSize)

	block := new(MoEBlock)
	block.GateWeight = onesLike(experts, hidden, inner)
	block.UpWeight = onesLike(experts, hidden, inner)
	block.DownWeight = onesLike(experts, inner, hidden)
	block.PerExpertScale = onesLike(experts)
	return block
}
t.Run("MoEBlock", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
// TestMoERouterForward runs Router.Forward on a [1, 3, HiddenSize] input built
// with onesLike and checks that both the returned scores and the returned
// indices have shape [B*L, TopKExperts].
func TestMoERouterForward(t *testing.T) {
useMLXTestThread(t)
out := moe.Forward(x, scores, inds, cfg)
mlx.Eval(out)
cfg := tinyMoEConfig()
B, L := int32(1), int32(3)
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
router := newRouter(cfg)
outDims := out.Dims()
t.Logf("MoE output shape: %v", outDims)
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
}
})
sDims := scores.Dims()
iDims := inds.Dims()
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
t.Run("MoEBlock_sorted", func(t *testing.T) {
bigB, bigL := int32(1), int32(128)
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
}
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
}
}
scores, inds := router.Forward(bigX, cfg)
mlx.Eval(scores, inds)
// TestMoEBlockForward routes a [1, 3, HiddenSize] input through Router.Forward,
// passes the resulting scores and indices to MoEBlock.Forward, and checks that
// the output shape is [B, L, HiddenSize].
func TestMoEBlockForward(t *testing.T) {
useMLXTestThread(t)
out := moe.Forward(bigX, scores, inds, cfg)
mlx.Eval(out)
cfg := tinyMoEConfig()
B, L := int32(1), int32(3)
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
router := newRouter(cfg)
moe := newMoEBlock(cfg)
outDims := out.Dims()
t.Logf("MoE sorted output shape: %v", outDims)
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
}
})
out := moe.Forward(x, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
}
}
// TestMoEBlockSortedForward runs the router and the MoE block on a 128-token
// input and verifies the output shape is [B, L, HiddenSize].
//
// NOTE(review): the long sequence is presumably what selects the sorted
// GatherMM path (B*L >= 64) — confirm against MoEBlock.Forward.
func TestMoEBlockSortedForward(t *testing.T) {
	useMLXTestThread(t)

	cfg := tinyMoEConfig()
	batch, seqLen := int32(1), int32(128)
	hidden := cfg.HiddenSize

	input := onesLike(int(batch), int(seqLen), int(hidden))
	router := newRouter(cfg)
	block := newMoEBlock(cfg)

	// Evaluate the routing decision before feeding it to the expert block.
	weights, experts := router.Forward(input, cfg)
	mlx.Eval(weights, experts)

	result := block.Forward(input, weights, experts, cfg)
	mlx.Eval(result)

	got := result.Dims()
	t.Logf("MoE sorted output shape: %v", got)

	shapeOK := len(got) == 3 &&
		got[0] == int(batch) &&
		got[1] == int(seqLen) &&
		got[2] == int(hidden)
	if !shapeOK {
		t.Errorf("output shape = %v, want [%d, %d, %d]", got, batch, seqLen, hidden)
	}
}
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
@ -106,7 +147,7 @@ func TestMoEForward(t *testing.T) {
// normalized scores as the legacy path that softmaxes over every expert
// first, gathers the top-k probabilities, then renormalizes.
func TestRouterForwardMatchesLegacy(t *testing.T) {
skipIfNoMLX(t)
useMLXTestThread(t)
cfg := &TextConfig{
HiddenSize: 8,