Modelinizi Nasıl Ölçeklersiniz: TPU'da LLM'lere Sistem Perspektifinden Bakış
(jax-ml.github.io)- Derin öğrenme performansını büyük ölçekte optimize etmek bir tür “simya” gibi görünebilir, ancak gerçekte anlaşılabilir basit ilkelerle model verimliliği artırılabilir
- Tek bir hızlandırıcıdan on binlerce hızlandırıcıya kadar, nispeten basit ilkeler her yerde geçerlidir ve bunları anlamak aşağıdaki gibi yararlı işler yapmayı mümkün kılar:
- Modelin her bölümünün teorik optimuma ne kadar yaklaştığını kabaca değerlendirmek
- Farklı ölçeklerde çeşitli paralelleştirme tekniklerini seçmek için bir temel oluşturmak
- Büyük Transformer modellerinin eğitimi ve çalıştırılması için gereken maliyet ve zamanı tahmin etmek
- Belirli donanımın özelliklerinden yararlanan algoritmalar tasarlamak
- Mevcut algoritma performansının sınırlarını açıkça anlayarak donanım tasarlamak
- Gerekli arka plan bilgisi
- LLM ve Transformer mimarisi hakkında temel kavramları anlamak gerekir
- Büyük ölçekli çalışma biçimlerine dair bilgi zorunlu değildir
- LLM eğitimi hakkında temel bilgi ve JAX kullanım deneyimi varsa daha iyi olur
- Transformer mimarisiyle ilgili blog yazıları ve JAX'te LLM ölçeklendirmesi hakkındaki slaytlara bakılması önerilir
- Hedefler
- Verilen bir donanım üzerinde modeli hangi şekilde paralelleştirmenin uygun olacağını tahmin edebilme yetkinliği kazanmak
- Eğitim ve çıkarım için gereken zaman ve maliyeti kabaca hesaplayabilme becerisi kazanmak
Neden ilgi göstermelisiniz?
- 3-4 yıl öncesine kadar çoğu ML araştırmacısının bu tür büyük ölçekli optimizasyonları derinlemesine bilmesine gerek yoktu
- Bugün ise “küçük” modeller bile donanım sınırlarına yakın çalıştığı için, büyük ölçekli verimli çalışma biçimlerini anlamak zorunlu hale geldi
- ML tarihi, sistem yenilikleri ile yazılım iyileştirmelerinin birbirini besleyerek geliştiği bir akış olarak görülebilir
- Son dönemde Transformer modellerinin donanım sınırlarına kadar kullanılması nedeniyle, model verimliliği anlaşılmadan geliştirilen yeni mimarilerin veya araştırmaların gerçek kullanımda başarısız olma ihtimali yüksektir
- Benchmark'ta %20 performans artışı elde etseniz bile, donanım verimliliği %20 düşerse sonuçta pratik değeri düşük kalır
- Model ölçeklendirmenin temel amacı, çiplerin sayısı arttığında iş hacminin doğrusal biçimde artmasını sağlamaktır
- Buna "güçlü ölçeklendirme" denir
- Çip eklemek hesaplama süresini kısaltır, ancak çipler arası iletişim maliyeti doğurur
- İletişim hesaplamadan daha uzun sürerse sistem "iletişim sınırlı" duruma gelir ve güçlü ölçeklendirme mümkün olmaz
- Donanımı yeterince iyi anlayıp bu darboğazların nerede oluşacağını öngörebilirseniz, modeli bunları önleyecek şekilde tasarlayabilir veya yeniden yapılandırabilirsiniz
- Bu kitabın amacı, TPU (ve GPU) donanımının nasıl çalıştığını ve Transformer mimarisinin mevcut donanımda iyi çalışacak şekilde nasıl evrildiğini açıklamaktır
- Bunun hem yeni mimariler tasarlayan araştırmacılar hem de mevcut nesil LLM'leri hızlı çalıştırmaya çalışan mühendisler için faydalı olması umuluyor
Genel bakış
- Bu yazı şu şekilde yapılandırılmıştır
- Bölüm 1, roofline analizi üzerinden modelin performans sınırlarını belirleyen unsurları (iletişim, hesaplama, bellek) açıklar
- Bölüm 2 ve Bölüm 3, TPU ve GPU'ların iç yapısını ve çipler arası bağlantı biçimlerini ele alır
- Bu sayede aşağıdaki sorular yanıtlanır
- Belirli boyuttaki bir matris çarpımı teorik olarak ne kadar hızlı yapılabilir?
- Hangi noktada hesaplama bellek bant genişliği ya da iletişim bant genişliğiyle sınırlanır?
- Bir TPU kümesi nasıl bağlanır ve veriyi bir çipten diğerine taşımak kabaca ne kadar sürer?
- Dağıtık matrisler verimli biçimde nasıl çarpılabilir?
- Bu sayede aşağıdaki sorular yanıtlanır
- Bölüm 4, Transformer mimarisinin formüllerini ayrıntılı biçimde ele alır (matris boyutları, parametre sayısı, FLOPs)
- Bölüm 5 ve Bölüm 7 ana bölümlerdir; modeli birden çok çip üzerinde paralelleştirmenin çeşitli yollarını tanıtır
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel
- ZeRO, Rematerialisation, Host offload, Gradient accumulation gibi bellek tasarrufu teknikleri de ele alınır
- Bölüm 6 ve Bölüm 8, TPU üzerinde LLaMA-3 eğitimi ve çıkarımını örnek alarak gerçek maliyet, süre ve yapılandırma biçimlerini sunar
- Son olarak Bölüm 9 ve Bölüm 10, JAX'te modeli profilleme, hata ayıklama ve paralel işleme uygulamanın pratik yollarını ele alır
Ayrıntılar: kitabın ana bölümlerinin özeti
-
Bölüm 1: Preliminaries
-
Bölüm 1: Basit Roofline analizine giriş
- Algoritmaları sınırlayan üç unsur: hesaplama, iletişim, bellek
- Buradan hareketle hesaplama hızının üst sınırının nasıl tahmin edileceği öğrenilir
-
- TPU'nun hesaplamayı nasıl yaptığı
- Systolic array yapısının ne olduğu
- TPU'nun bellek ve iletişim bant genişliğini nasıl sağladığına dair temel bir anlayış
-
Bölüm 3: Dağıtık matrisler ve dağıtık çarpım
- Model parametrelerini birden fazla çipe bölerek saklama (Sharding) tekniği
- Dağıtık matris işlemlerinde ortaya çıkan iletişim ve darboğazların nasıl ele alındığı
-
-
Bölüm 2: Transformers
-
Bölüm 4: Gerekli Transformer formüllerinin derlenmesi
- Transformer'da matris çarpımlarının somut olarak hangi biçimde olduğu
- Parametre sayısı, FLOPs, KV cache boyutu gibi değerlerin nasıl hesaplandığı
- Attention işlemlerinin Feed-Forward bloklarına kıyasla ne kadar hesaplama gerektirdiğini anlamak
-
Bölüm 5: Transformer eğitimi için paralelleştirme stratejileri
- Data parallel, Tensor parallel, Pipeline parallel, Expert parallel tekniklerine giriş
- ZeRO(FSDP), Rematerialisation, Gradient accumulation, Host offload gibi bellek tasarrufu yöntemleri
- Belirli model boyutu ve çip sayısına göre paralelleştirme kurgulama anlayışı
-
Bölüm 6: LLaMA 3'ün TPU eğitimine uygulanması
- Gerçek bir TPU ortamında LLaMA 3 modelinin eğitildiği varsayıldığında, süre ve maliyet tahmini
- Batch size, paralelleştirme yöntemi, bellek kullanımı gibi konularda somut örnekler
-
Bölüm 7: Transformer çıkarımı hakkında her şey
- Çıkarımda gecikme (latency), önemli yeni bir etken olarak öne çıkar
- KV cache gibi unsurların yol açtığı bellek kullanımı ve iletişim sorunları
- Model sunumu için birden fazla çipin nasıl tahsis edilip bağlanacağına dair tartışmalar
-
Bölüm 8: LLaMA 3'ün TPU üzerinde sunuma uygulanması
- TPU v5e üzerinde LLaMA 3 sunumu varsayımında, yaklaşık maliyet, gecikme ve iş hacmi arasındaki trade-off analizi
-
-
Bölüm 3: Practical Tutorials
-
Bölüm 9: TPU kodu nasıl profillenir
- JAX+XLA yığınını anlamak
- Gerçek performans düşüşü sorunlarını tespit etme ve çözüm yolları
- JAX/TensorBoard profiler kullanım şekli
-
Bölüm 10: JAX ile TPU programlama
- JAX'in paralelleştirme API'lerini (primitives) kullanma
- Örnekler ve alıştırmalarla paralel hesaplama kavramlarını öğrenme
-
Bölüm 11: Sonuç ve ek kaynaklar
- TPU ve LLM hakkında ek okuma önerileri
- Genel içeriğin kısa bir kapanışı ve geleceğe dair değerlendirmeler
-
1 yorum
Hacker News yorumu