AMD GPU ile Llama 405B ince ayarı
(publish.obsidian.md)Felafax BlogTune Llama3 405B on AMD MI300x (yolculuğumuz)
Giriş
- Açık kaynak modeller büyüdükçe, büyük ölçekli yapay zeka eğitimini kaldırabilecek güçlü altyapıya olan ihtiyaç artıyor
- Felafax, LLaMA 3.1 405B modelini AMD GPU üzerinde ince ayar yaparak AMD donanımının verimliliğini gösterdi
- Tüm çalışmayı GitHub'da açık kaynak olarak yayımladı
- AMD MI300X GPU, NVIDIA yapay zeka donanımına kıyasla yüksek performans sunuyor
- Proje, TensorWave'in desteği sayesinde mümkün oldu
JAX nedir ve neden seçildi
- JAX, NumPy benzeri bir API, otomatik türev alma ve Google'ın XLA derleyicisini birleştiren güçlü bir makine öğrenimi kütüphanesidir
- Model paralelleştirme için mükemmel API'ler sunar ve bu da onu büyük ölçekli model eğitimi için ideal hale getirir
JAX'in avantajları
- Saf fonksiyonlar: JAX, saf fonksiyonlar yazmayı teşvik eder; bu da kodun kurgulanmasını, hata ayıklanmasını ve okunmasını kolaylaştırır
- Gelişmiş paralelleştirme: JAX'in esnek JIT API'si, büyük ölçekli eğitim için kritik olan gelişmiş veri ve model paralelliğini destekler
- Temiz kod tabanı: JAX'in tasarım felsefesi, donanım platformları arasında taşınabilir kod yazmayı teşvik eder
JAX neden NVIDIA dışı donanımlarda öne çıkıyor
- Donanımdan bağımsız yaklaşım: JAX, hesaplamaları donanımdan bağımsız bir ara gösterime derlemek için XLA derleyicisini kullanır
- Platformdan bağımsız optimizasyon: XLA derleyicisi, optimizasyonları donanımdan bağımsız olarak gerçekleştirir
- Kolay taşınabilirlik: JAX kullanıldığında NVIDIA'dan AMD'ye geçerken kod değişiklikleri minimum düzeyde kalır
AMD GPU üzerinde JAX kurulumu
- Docker imajı çekilip konteyner başlatıldıktan sonra kurulum doğrulanır
- LLaMA 405B modeli, 8 adet AMD MI300x GPU kullanılarak eğitilir
LLaMA 405B eğitimi: performans ve ölçeklenebilirlik
- LLaMA 405B modeli, JAX kullanılarak AMD GPU üzerinde eğitilir
- LoRA ince ayarıyla model ağırlıkları ve LoRA parametreleri bfloat16 hassasiyetinde ayarlanır
- Model boyutu: yaklaşık 800 GB VRAM kullanır
- LoRA ağırlıkları ve optimizer durumu: yaklaşık 400 GB VRAM kullanır
- Toplam VRAM kullanımı: yaklaşık 1200 GB
- Eğitim hızı: saniyede yaklaşık 35 token
- Bellek verimliliği: yaklaşık %70 korunur
- Ölçeklenebilirlik: JAX kullanılarak 8 GPU üzerinde neredeyse doğrusal biçimde ölçeklenir
Eğitim kurulumumuz
- LLaMA 3.1, PyTorch'tan JAX'e dönüştürüldü
- Model yükleme ve parametre shard etme yoluyla verimli şekilde dağıtıldı
JAX'te parametre shard etme
- JAX'in device mesh özelliği kullanılarak model 8 AMD GPU'ya verimli biçimde dağıtılır
- Parametre shard kuralları tanımlanarak her tensörün boyutları mesh eksenlerine göre shard edilir
LoRA eğitim uygulaması
- LoRA, ağırlık güncellemelerini düşük dereceli matrislere ayırarak eğitilebilir parametre sayısını azaltır
- LoRA parametrelerini içeren
LoRADensekatmanı uygulanır - LoRA parametreleri verimli biçimde dağıtılarak bellek kullanımı ve hesaplama verimliliği optimize edilir
Sonuç
- AMD GPU ve JAX kullanarak LLaMA 3.1 405B modeline ince ayar yapma deneyimi oldukça olumluydu
- JAX'in güçlü paralelleştirme yetenekleri ve donanımdan bağımsız yaklaşımı kullanılarak model verimli biçimde dağıtıldı
- AMD GPU'ların büyük ölçekli yapay zeka eğitimi için güçlü bir alternatif olduğu gösterildi
- Tüm kod GitHub deposunda incelenebilir ve doğrudan çalıştırılabilir
GN⁺ özeti
- Bu yazı, AMD GPU ve JAX kullanarak büyük ölçekli yapay zeka modellerinin nasıl verimli biçimde eğitilebileceğini açıklıyor
- AMD donanımının NVIDIA'ya kıyasla maliyet açısından verimli bir alternatif olduğunu vurguluyor
- JAX'in donanımdan bağımsız yaklaşımı, kod taşınabilirliğini artırıyor ve bakımı kolaylaştırıyor
- Büyük ölçekli model eğitimiyle ilgilenenler için faydalı bilgiler ve uygulamalı kod sağlıyor
- Benzer işlevlere sahip projeler arasında NVIDIA'nın CUDA'sı ve PyTorch yer alıyor
1 yorum
Hacker News yorumları
Yakın zamanda 8xAMD MI300x GPU üzerinde PyTorch yerine JAX ile llama3.1 405B modeline ince ayar yaptık
JAX’in gelişmiş parçalama API’si sayesinde iyi performans aldık; kullandığımız parçalama tekniklerini blogda özetledik. Kodu da yayımladık: https://github.com/felafax/felafax
NVIDIA dışı donanımlarda (TPU, AMD, Trainium) LLM ince ayarı ve sunumu için yapay zeka altyapısı geliştiren küçük bir startup’ız
Birçok şirket AMD GPU’larda PyTorch çalıştırmaya çalışıyor; ancak PyTorch’un
torch.cudaya dascaled_dot_product_attentiongibi NVIDIA ekosistemiyle derinden iç içe olduğunu ve ciddi bir “NVIDIA’dan arındırma” gerektirdiğini düşünüyoruzJAX’in, model kodunu donanımdan bağımsız bir HLO grafiğine derleyip ardından XLA derleyicisinin optimize etmesi ve sonrasında donanıma özel optimizasyonlar uygulaması nedeniyle NVIDIA dışı donanımlara daha uygun olduğunu düşünüyoruz. Aynı LLaMA3 JAX kodu Google TPU ve AMD GPU üzerinde değişiklik gerektirmeden çalıştı
Şirket stratejimiz, modelleri önce JAX’e taşımak, sonra JAX framework’ü ve XLA kernel’lerinden yararlanarak NVIDIA dışı backend’lerde maksimum performansı çıkarmak. Bu yüzden önce Llama 3.1’i PyTorch’tan JAX’e taşıdık; aynı JAX modeli TPU ve AMD GPU üzerinde iyi çalışıyor
Kişisel olarak PyTorch kullanmamın başlıca nedeni, orijinal modelin PyTorch ile yapılmış olması. Farklı model sürümlerinde mantık aynı görünse bile, devasa veri ölçeklerinde çok küçük kayan nokta hataları birikerek model drift’ine yol açabilir
Büyük modellerde bu tür doğruluk uyuşmazlıklarını debug etmek, cehennemin 10. çemberinden bile beter bir işe yakındı
hipblaslt, Composable Kernel FA gibi şeylerJAX’i pek iyi bilmiyorum ama MI300x’te PyTorch eğitim performansının berbat olmasının önemli bir kısmının, içeride kullanılan ROCm kütüphanelerinin yavaş olmasından kaynaklandığını düşünüyorum
Burada çalışmaktan kastım, sürücüyü çalıştırmak için 2 hafta harcadıktan sonra sunucuyu bir daha asla güncelleyememek değil
Karşılaştığınız teknik sorunları da merak ediyorum
Açık konuşmak gerekirse bu performans oldukça kötü. Muhtemelen derlemeyi düzgün çalıştıramamış olmanızdan kaynaklanıyor
405B modelde 35 token/sn elde ediliyor; bu yaklaşık 85 teraflops’a denk geliyor. 8 adet MI300x GPU ise 10,4 petaflops seviyesinde, yani MFU yaklaşık %0,8
Makul eğitim performansı sayılan %30~40 MFU’dan 40~50 kat düşük; bu yüzden AMD açısından darboğazın yazılım yığını olmasını umuyorlardır
GitHub sayfasında “Google Cloud TPU’da LLaMa3.1’i %30 daha düşük maliyetle fine-tune edebilirsiniz” deniyor ama performanstan söz edilmiyor
Harika iş. Yaklaşık bir yıl önce AMD GPU ve ROCm desteğiyle biraz uğraşmıştım; AMD’nin Nvidia’yı yakalaması için hâlâ kat etmesi gereken uzun bir yol olduğu açıktı
JAX’i seçen yaklaşım ilginç; makine öğrenmesinin standart kütüphanesine yakın olan PyTorch’tan uzaklaşırken ne gibi zorluklar yaşadığınızı merak ediyorum
Başta hedefimiz TPU üzerinde LLaMA 3’e ince ayar yapmaktı, ancak PyTorch XLA hantal olduğu için modeli JAX ile yeniden yazmaya karar verdik
Daha önce dediğim gibi JAX’i NVIDIA dışı GPU’lar için daha iyi bir platform olarak görüyoruz ve JAX+openXLA üzerinde NVIDIA dışı GPU’lara yönelik altyapı oluşturmak istiyoruz
Güzel iş. Geçen hafta sonu ben de 405B’nin çıkarım tarafıyla uğraşıyordum [0]
torch.cuda’nın bu kadar kötü olduğundan emin değilim. Çünkü AMD için PyTorch onu sizin yerinize çeviriyor. Bu özsel bir sorundan çok isimlendirme sorunu gibiAslında
rocm:pytorchcontainer’ını çekmek,rocm:jaxcontainer’ını çekmek kadar kolayYayımlanmış çok sayı yok; MFU’nuzun ne olduğunu merak ediyorum
[0] https://x.com/HotAisle/status/1837580046732874026
MFU’yu hesaplamamız gerekiyor. GPU ve VRAM ayrıntıları depoda görülebilir: https://dub.sh/amd-405b-res
Önümüzdeki hafta sonu eğitim çalıştırmasını yeniden denemeyi, tüm eğitim adımını JIT derlemeyi ve o sırada MFU’yu hesaplamayı planlıyoruz
ZML’de ölçtüğümüzde MI300X, H100’den %30 daha hızlıydı. Harika çipler
8xAMD MI300 host kiralayabileceğimiz bir bulut sağlayıcısı var mı merak ediyorum
İş gereği AWS’i çok kullanıyorum ve AMD GPU’yu bir denemek istemiştim
Performans verileri nerede?
Kod ve VRAM kısıtları nedeniyle 405B modelinin JIT derlenmiş sürümünü çalıştıramadık. Bu kısmı daha fazla araştırmamız gerekiyor
Tüm eğitim çalıştırması JAX eager modu ile yapıldı, dolayısıyla performans iyileştirmesi için geniş alan var
Eager modda bile GPU kullanımı genel olarak yaklaşık %30~40’tı ve bu oldukça iyi sayılır. JIT kullanırsak GPU kullanımının kolayca %50~60’a çıkabileceğini düşünüyorum
Mümkünse bellek kısıtlarını aşarak JIT derlenmiş sürümü çalıştırmanın yollarını araştırmak ilginç olurdu. Ek performans iyileştirmesine yol açabilir
JIT derlenmiş eğitim adımına, daha optimize veri yükleme ve parçalamaya, gradyan biriktirmeye ve aktivasyon checkpointing’e ihtiyacımız var
Geliştirmeye devam edip tüm iyileştirmeleri uyguladıktan sonra yakında tekrar blog yazacağız
AMD’nin, GPU toplu siparişleri ve tedarik sıkıntısı sayesinde buradan değer çıkarmaya biraz olsun yaklaşıp yaklaşmadığını merak ediyorum
Benim izlenimim “hayır”a yakın
Karşımızdakinin muazzam bir ilk hamle avantajı var ve yazılım tarafında yapılacak çok iş olduğu açık. Zamana ihtiyaç var
Bir not alma uygulaması olan Obsidian bunu neden yapıyor?