Felafax BlogTune Llama3 405B on AMD MI300x (yolculuğumuz)
Giriş
- Açık kaynak modeller büyüdükçe, büyük ölçekli yapay zeka eğitimini kaldırabilecek güçlü altyapıya olan ihtiyaç artıyor
- Felafax, LLaMA 3.1 405B modelini AMD GPU üzerinde ince ayar yaparak AMD donanımının verimliliğini gösterdi
- Tüm çalışmayı GitHub'da açık kaynak olarak yayımladı
- AMD MI300X GPU, NVIDIA yapay zeka donanımına kıyasla yüksek performans sunuyor
- Proje, TensorWave'in desteği sayesinde mümkün oldu
JAX nedir ve neden seçildi
- JAX, NumPy benzeri bir API, otomatik türev alma ve Google'ın XLA derleyicisini birleştiren güçlü bir makine öğrenimi kütüphanesidir
- Model paralelleştirme için mükemmel API'ler sunar ve bu da onu büyük ölçekli model eğitimi için ideal hale getirir
JAX'in avantajları
- Saf fonksiyonlar: JAX, saf fonksiyonlar yazmayı teşvik eder; bu da kodun kurgulanmasını, hata ayıklanmasını ve okunmasını kolaylaştırır
- Gelişmiş paralelleştirme: JAX'in esnek JIT API'si, büyük ölçekli eğitim için kritik olan gelişmiş veri ve model paralelliğini destekler
- Temiz kod tabanı: JAX'in tasarım felsefesi, donanım platformları arasında taşınabilir kod yazmayı teşvik eder
JAX neden NVIDIA dışı donanımlarda öne çıkıyor
- Donanımdan bağımsız yaklaşım: JAX, hesaplamaları donanımdan bağımsız bir ara gösterime derlemek için XLA derleyicisini kullanır
- Platformdan bağımsız optimizasyon: XLA derleyicisi, optimizasyonları donanımdan bağımsız olarak gerçekleştirir
- Kolay taşınabilirlik: JAX kullanıldığında NVIDIA'dan AMD'ye geçerken kod değişiklikleri minimum düzeyde kalır
AMD GPU üzerinde JAX kurulumu
- Docker imajı çekilip konteyner başlatıldıktan sonra kurulum doğrulanır
- LLaMA 405B modeli, 8 adet AMD MI300x GPU kullanılarak eğitilir
LLaMA 405B eğitimi: performans ve ölçeklenebilirlik
- LLaMA 405B modeli, JAX kullanılarak AMD GPU üzerinde eğitilir
- LoRA ince ayarıyla model ağırlıkları ve LoRA parametreleri bfloat16 hassasiyetinde ayarlanır
- Model boyutu: yaklaşık 800 GB VRAM kullanır
- LoRA ağırlıkları ve optimizer durumu: yaklaşık 400 GB VRAM kullanır
- Toplam VRAM kullanımı: yaklaşık 1200 GB
- Eğitim hızı: saniyede yaklaşık 35 token
- Bellek verimliliği: yaklaşık %70 korunur
- Ölçeklenebilirlik: JAX kullanılarak 8 GPU üzerinde neredeyse doğrusal biçimde ölçeklenir
Eğitim kurulumumuz
- LLaMA 3.1, PyTorch'tan JAX'e dönüştürüldü
- Model yükleme ve parametre shard etme yoluyla verimli şekilde dağıtıldı
JAX'te parametre shard etme
- JAX'in device mesh özelliği kullanılarak model 8 AMD GPU'ya verimli biçimde dağıtılır
- Parametre shard kuralları tanımlanarak her tensörün boyutları mesh eksenlerine göre shard edilir
LoRA eğitim uygulaması
- LoRA, ağırlık güncellemelerini düşük dereceli matrislere ayırarak eğitilebilir parametre sayısını azaltır
- LoRA parametrelerini içeren
LoRADense katmanı uygulanır
- LoRA parametreleri verimli biçimde dağıtılarak bellek kullanımı ve hesaplama verimliliği optimize edilir
Sonuç
- AMD GPU ve JAX kullanarak LLaMA 3.1 405B modeline ince ayar yapma deneyimi oldukça olumluydu
- JAX'in güçlü paralelleştirme yetenekleri ve donanımdan bağımsız yaklaşımı kullanılarak model verimli biçimde dağıtıldı
- AMD GPU'ların büyük ölçekli yapay zeka eğitimi için güçlü bir alternatif olduğu gösterildi
- Tüm kod GitHub deposunda incelenebilir ve doğrudan çalıştırılabilir
GN⁺ özeti
- Bu yazı, AMD GPU ve JAX kullanarak büyük ölçekli yapay zeka modellerinin nasıl verimli biçimde eğitilebileceğini açıklıyor
- AMD donanımının NVIDIA'ya kıyasla maliyet açısından verimli bir alternatif olduğunu vurguluyor
- JAX'in donanımdan bağımsız yaklaşımı, kod taşınabilirliğini artırıyor ve bakımı kolaylaştırıyor
- Büyük ölçekli model eğitimiyle ilgilenenler için faydalı bilgiler ve uygulamalı kod sağlıyor
- Benzer işlevlere sahip projeler arasında NVIDIA'nın CUDA'sı ve PyTorch yer alıyor
1 yorum
Hacker News görüşleri
JAX kullanarak Llama3.1 405B modelini 8xAMD MI300x GPU üzerinde fine-tune etme başarısını paylaşıyor
Bellek kısıtlarını aşıp JIT derlenmiş sürümü çalıştırma yöntemlerinin araştırılmasını öneriyor
AMD GPU ve ROCm desteğiyle ilgili deneyimini paylaşıyor
405B modelinin inference tarafında yaptığı deneyi paylaşıyor
torch.cuda'nın o kadar da kötü olmadığını düşünüyorrocm:pytorchcontainer'ını kullanmanınrocm:jaxcontainer'ını kullanmak kadar kolay olduğunu belirtiyorPerformans verisinin olmamasına dair soru soruyor
Obsidian'ın (not alma uygulaması) neden bunu yaptığına dair soru işareti dile getiriyor
@dang'den URL'lerde kullanıcı adının yer almasını istiyor