2 puan yazan GN⁺ 2024-09-24 | 1 yorum | WhatsApp'ta paylaş

Felafax BlogTune Llama3 405B on AMD MI300x (yolculuğumuz)

Giriş

  • Açık kaynak modeller büyüdükçe, büyük ölçekli yapay zeka eğitimini kaldırabilecek güçlü altyapıya olan ihtiyaç artıyor
  • Felafax, LLaMA 3.1 405B modelini AMD GPU üzerinde ince ayar yaparak AMD donanımının verimliliğini gösterdi
  • Tüm çalışmayı GitHub'da açık kaynak olarak yayımladı
  • AMD MI300X GPU, NVIDIA yapay zeka donanımına kıyasla yüksek performans sunuyor
  • Proje, TensorWave'in desteği sayesinde mümkün oldu

JAX nedir ve neden seçildi

  • JAX, NumPy benzeri bir API, otomatik türev alma ve Google'ın XLA derleyicisini birleştiren güçlü bir makine öğrenimi kütüphanesidir
  • Model paralelleştirme için mükemmel API'ler sunar ve bu da onu büyük ölçekli model eğitimi için ideal hale getirir

JAX'in avantajları

  • Saf fonksiyonlar: JAX, saf fonksiyonlar yazmayı teşvik eder; bu da kodun kurgulanmasını, hata ayıklanmasını ve okunmasını kolaylaştırır
  • Gelişmiş paralelleştirme: JAX'in esnek JIT API'si, büyük ölçekli eğitim için kritik olan gelişmiş veri ve model paralelliğini destekler
  • Temiz kod tabanı: JAX'in tasarım felsefesi, donanım platformları arasında taşınabilir kod yazmayı teşvik eder

JAX neden NVIDIA dışı donanımlarda öne çıkıyor

  • Donanımdan bağımsız yaklaşım: JAX, hesaplamaları donanımdan bağımsız bir ara gösterime derlemek için XLA derleyicisini kullanır
  • Platformdan bağımsız optimizasyon: XLA derleyicisi, optimizasyonları donanımdan bağımsız olarak gerçekleştirir
  • Kolay taşınabilirlik: JAX kullanıldığında NVIDIA'dan AMD'ye geçerken kod değişiklikleri minimum düzeyde kalır

AMD GPU üzerinde JAX kurulumu

  • Docker imajı çekilip konteyner başlatıldıktan sonra kurulum doğrulanır
  • LLaMA 405B modeli, 8 adet AMD MI300x GPU kullanılarak eğitilir

LLaMA 405B eğitimi: performans ve ölçeklenebilirlik

  • LLaMA 405B modeli, JAX kullanılarak AMD GPU üzerinde eğitilir
  • LoRA ince ayarıyla model ağırlıkları ve LoRA parametreleri bfloat16 hassasiyetinde ayarlanır
  • Model boyutu: yaklaşık 800 GB VRAM kullanır
  • LoRA ağırlıkları ve optimizer durumu: yaklaşık 400 GB VRAM kullanır
  • Toplam VRAM kullanımı: yaklaşık 1200 GB
  • Eğitim hızı: saniyede yaklaşık 35 token
  • Bellek verimliliği: yaklaşık %70 korunur
  • Ölçeklenebilirlik: JAX kullanılarak 8 GPU üzerinde neredeyse doğrusal biçimde ölçeklenir

Eğitim kurulumumuz

  • LLaMA 3.1, PyTorch'tan JAX'e dönüştürüldü
  • Model yükleme ve parametre shard etme yoluyla verimli şekilde dağıtıldı

JAX'te parametre shard etme

  • JAX'in device mesh özelliği kullanılarak model 8 AMD GPU'ya verimli biçimde dağıtılır
  • Parametre shard kuralları tanımlanarak her tensörün boyutları mesh eksenlerine göre shard edilir

LoRA eğitim uygulaması

  • LoRA, ağırlık güncellemelerini düşük dereceli matrislere ayırarak eğitilebilir parametre sayısını azaltır
  • LoRA parametrelerini içeren LoRADense katmanı uygulanır
  • LoRA parametreleri verimli biçimde dağıtılarak bellek kullanımı ve hesaplama verimliliği optimize edilir

Sonuç

  • AMD GPU ve JAX kullanarak LLaMA 3.1 405B modeline ince ayar yapma deneyimi oldukça olumluydu
  • JAX'in güçlü paralelleştirme yetenekleri ve donanımdan bağımsız yaklaşımı kullanılarak model verimli biçimde dağıtıldı
  • AMD GPU'ların büyük ölçekli yapay zeka eğitimi için güçlü bir alternatif olduğu gösterildi
  • Tüm kod GitHub deposunda incelenebilir ve doğrudan çalıştırılabilir

GN⁺ özeti

  • Bu yazı, AMD GPU ve JAX kullanarak büyük ölçekli yapay zeka modellerinin nasıl verimli biçimde eğitilebileceğini açıklıyor
  • AMD donanımının NVIDIA'ya kıyasla maliyet açısından verimli bir alternatif olduğunu vurguluyor
  • JAX'in donanımdan bağımsız yaklaşımı, kod taşınabilirliğini artırıyor ve bakımı kolaylaştırıyor
  • Büyük ölçekli model eğitimiyle ilgilenenler için faydalı bilgiler ve uygulamalı kod sağlıyor
  • Benzer işlevlere sahip projeler arasında NVIDIA'nın CUDA'sı ve PyTorch yer alıyor

1 yorum

 
GN⁺ 2024-09-24
Hacker News görüşleri
  • JAX kullanarak Llama3.1 405B modelini 8xAMD MI300x GPU üzerinde fine-tune etme başarısını paylaşıyor

    • JAX'in gelişmiş sharding API'leri sayesinde güçlü performans elde edildi
    • Blog yazısı ve açık kaynak kod bağlantısı paylaşılıyor: GitHub bağlantısı
    • NVIDIA donanımı yerine TPU, AMD ve Trainium üzerinde LLM'leri fine-tune eden ve servis eden yapay zeka altyapısı kuran bir startup olduklarını belirtiyor
    • Birçok şirketin AMD GPU'larda PyTorch çalıştırmaya uğraştığını, ancak bunun zor bir yol olduğunu düşündüğünü söylüyor
    • PyTorch'un NVIDIA ekosistemiyle derin biçimde bağlantılı olduğunu, bu yüzden NVIDIA dışı donanımda çalıştırmak için çok sayıda değişiklik gerektiğini ifade ediyor
    • JAX'in NVIDIA dışı donanım için daha uygun olduğuna inanıyor
    • JAX'te ML model kodu donanımdan bağımsız HLO grafiğine derleniyor ve XLA derleyicisi donanıma özgü optimizasyonları yapıyor
    • Aynı JAX kodu, hiçbir değişiklik olmadan Google TPU ve AMD GPU üzerinde çalışabiliyor
    • Şirket stratejisinin modelleri JAX'e taşımak ve NVIDIA dışı backend'lerde azami performans çıkarmak için XLA kernel'lerini kullanmak olduğunu söylüyor
    • Llama 3.1'i ilk olarak PyTorch'tan JAX'e taşıdıklarını ve artık aynı JAX modelinin TPU ve AMD GPU üzerinde iyi çalıştığını belirtiyor
    • Vizyonları ve repo hakkında görüş duymak istiyor
  • Bellek kısıtlarını aşıp JIT derlenmiş sürümü çalıştırma yöntemlerinin araştırılmasını öneriyor

    • Bunun ek performans artışı sağlayabileceğini düşünüyor
  • AMD GPU ve ROCm desteğiyle ilgili deneyimini paylaşıyor

    • Bir yıl önce AMD GPU ve ROCm desteğini denediğini, ancak AMD'nin NVIDIA'yı yakalamasına daha çok yol olduğunu hissettiğini söylüyor
    • JAX seçiminin ilginç bir yaklaşım olduğunu, ancak PyTorch'tan uzaklaşırken ne tür zorluklar yaşandığını merak ediyor
  • 405B modelinin inference tarafında yaptığı deneyi paylaşıyor

    • torch.cuda'nın o kadar da kötü olmadığını düşünüyor
    • AMD sürümündeki PyTorch'un bunu çevirdiğini, dolayısıyla bunun sadece bir isim meselesi olduğunu söylüyor
    • rocm:pytorch container'ını kullanmanın rocm:jax container'ını kullanmak kadar kolay olduğunu belirtiyor
    • Fazla performans verisi yayımlanmadığına dikkat çekiyor
    • MFU (model kullanım oranı) rakamlarını merak ediyor
  • Performans verisinin olmamasına dair soru soruyor

    • AMD GPU'ların büyük miktarda sipariş edilmesi nedeniyle buradan değer çıkarma ihtimalini sorguluyor
    • Aldığı izlenimin "hayır" yönünde olduğunu söylüyor
  • Obsidian'ın (not alma uygulaması) neden bunu yaptığına dair soru işareti dile getiriyor

    • İlk başta bunun Obsidian'ın kendi gönderisi olduğunu düşündüğünü söylüyor
    • GitHub.com ile GitHub.io'nun neden hâlâ ayırt edilmediğini sorguluyor
  • @dang'den URL'lerde kullanıcı adının yer almasını istiyor

    • Bu gönderinin Obsidian'ın kendisiyle değil, kullanıcı tarafından oluşturulmuş bir blogla ilgili olduğunu belirtiyor