14 puan yazan xguru 2024-08-19 | 8 yorum | WhatsApp'ta paylaş
  • PyTorch'un üretkenlik kaybına ve geliştirme zamanının boşa harcanmasına yol açmasının nedeni, "çerçevenin kendisinin kötü olması değil, şu anda uygulandığı kullanım senaryolarına uygun şekilde tasarlanmamış olması"

PyTorch'un felsefesi

  • PyTorch'un felsefesi dinamiklik, hata ayıklama kolaylığı ve Pythonic olmaktır
  • Buna karşılık TensorFlow 1.x, XLA derleyicisini yoğun biçimde kullanarak statik ama yüksek performanslı bir çerçeve olmayı hedefledi
  • TensorFlow geliştiricileri, topluluğun 1.x API'sini sevmediğini fark etti ve ana arayüz olarak Keras'ı kullanmaya karar verip XLA derleyicisinin rolünü küçülttü
  • PyTorch köklerine sadık kaldı ve TensorFlow'un statik ve ertelenmiş yaklaşımının aksine, torch.Tensor'ın anında değerlendirildiği daha dinamik bir "eager execution" yaklaşımını benimsedi
  • Bu da sonuç verince birçok araştırma PyTorch'a geçti
  • 2021'de GPT-3'ün ortaya çıkmasıyla performans ve ölçeklenebilirlik başlıca ilgi alanı haline geldi
  • PyTorch bu talebe bir ölçüde iyi karşılık verdi, ancak bu felsefe gözetilerek tasarlanmadığı için giderek teknik borç birikmeye ve temelleri sarsılmaya başladı
  • PyTorch geliştiricileri hiçbir taviz istemedi ve iki yolu aynı anda izlemeyi seçti
    • XLA derleyicisini performans ve kararlılığı yüksek varsayılan backend olarak kullanmak
    • Gerektiğinde kullanıcıların derleyiciyi çağırma özgürlüğüne sahip olması için torch.compile yığınını inşa etmek
  • Uzun vadeli strateji eksikliği ciddi bir sorundur
  • PyTorch, derleyici merkezli bir felsefeye (JAX gibi) bağlanmak istemiyor ama iyi bir alternatif de görünmüyor
  • Rakip ürünlerin bu soruna çözümü ne?

JAX'in derleyici tabanlı geliştirmesi

  • JAX, TensorFlow'un güçlü derleyici yığını XLA'dan yararlanır
  • XLA güçlü bir derleyicidir, ancak son kullanıcı için tümüyle soyutlanmıştır
  • Bir fonksiyon saf (pure) olduğu sürece @jax.jit dekoratörü kullanılarak JIT derlenebilir ve XLA üzerinde çalıştırılabilir
  • XLA; üretilen grafiğin doğruluğunu doğrulama, JAX'te sharding ile otomatik paralelleştirmeyi yürüten GSPMD partitioner, grafik optimizasyonu, operatör ve kernel fusion, gecikme gizleme zamanlaması, asenkron iletişim örtüşmesi, triton gibi diğer backend'ler için kod üretimi gibi işleri perde arkasında halleder
  • JAX kısıtlarına uyduğunuz sürece XLA bunları otomatik olarak üstlenir
  • Örneğin paralelleştirme sırasında torch.distributed.barrier() gibi iletişim primitive'lerine ihtiyaç duymazsınız
  • DDP desteği basit bir kodla mümkündür
  • XLA'nın yaklaşımı, hesaplamanın sharding'i takip etmesidir. Dolayısıyla giriş dizisi bir eksen boyunca shard edilirse XLA alt hesaplamaları otomatik olarak buna göre yönetir
  • "Derleyici tabanlı geliştirme" fikri, Rust derleyicisinin çalışma biçimine benzer
  • PyTorch'un sınırlamaları
    • PyTorch geliştiricilerinin esneklik ve özgürlük şeklindeki temel felsefeyi korumak yerine yeni özellikler için derleyici yığınını entegre edip ona bağımlı olma tercihinden memnun değilim
    • PyTorch 2.x'in resmi yol haritasına göre, XLA'yı Torch ile tamamen entegre etmeye yönelik uzun vadeli plan açıkça belirtiliyor
    • Bu korkunç bir fikir. Rust derleyicisine zorla C++ kodu sokmanın, doğrudan Rust kullanmaktan daha iyi bir deneyim olacağını söylemeye benziyor
    • Torch, JAX'in aksine XLA merkezli olarak tasarlanmadı
    • PyTorch XLA tabanlı bir derleyici yığını kullanmaya karar veriyorsa, ideal çerçevenin bunun etrafında özel olarak tasarlanıp inşa edilmiş olması gerekmez mi?
    • PyTorch, istediği derleyici backend'ini seçebileceğiniz bir "multi-backend" yaklaşımını izlese bile, bu parçalanma sorununu daha da büyütüp tüm derleyici yığınlarının kısıtlarına uymaya çalışırken API'yi tamamen mahvetmez mi?
    • TPU üzerinde Torch/XLA kullanmış herkes ağır bir PTSD yaşıyor

Multi-Backend başarısız oldu

  • PyTorch her şeyi aynı anda yapmaya çalışarak feci şekilde başarısız oldu
  • "Multi-backend" tasarım kararı bu sorunu katlanarak büyüttü
  • Teoride istediğiniz yığını seçebiliyormuşsunuz gibi geliyor ama pratikte anlaşılması zor traceback'ler ve uyumsuzluk sorunlarının düğümlenmiş kaosu ortaya çıkıyor
  • Backend'ler arası kısıtlar ile PyTorch API'sinin çatışması
    • Sorun bu backend'leri çalıştırmanın zor olması değil; bu backend'lerin beklediği kısıtların PyTorch'un esnek ve Pythonic API'siyle iyi uyuşmaması
    • API tutarlılığını korumak ile backend kısıtlarına uymak arasında bir trade-off var
    • Sonuç olarak geliştiriciler, tek bir backend ile gerçekten bütünleşip ona bağlanmak yerine kod üretimine daha fazla yaslanmaya çalışıyor
  • PyTorch'un strateji eksikliği
    • PyTorch anlamlı trade-off'ları reddettiği için her karar bir uzlaşma gibi hissettiriyor
    • Ne tutarlılık var ne de genel bir strateji
    • Sonuçta kullanıcıda büyük bir hayal kırıklığı yaratıyor ve birbiriyle uyumsuz özelliklerin karması gibi duruyor
    • Ekosistemi öldürmenin daha hızlı bir yolu yok
  • Neden JAX yaklaşımı izlenmemeli
    • PyTorch, JAX'in "birleşik derleyici ve backend" yaklaşımını izlememeli
    • Çünkü JAX açıkça XLA ile çalışacak şekilde tasarlandı
    • PyTorch frontend'ini JAX'inkiyle değiştirmek bir strateji olamaz
    • XLA üzerinde JAX'ten daha iyi bir API tasarlamak fiilen imkânsızdır
    • Geliştiricilerin yeni ve farklı fikirler denemesini eleştirmiyorum
    • Ancak PyTorch'un zamanın sınavına dayanabilmesi için, ideal eğitim senaryolarının dışında anında dağılan havalı yeni özellikler sunmaktan çok temelini güçlendirmeye odaklanması gerekir

PyTorch'un parçalanması ve JAX'in fonksiyonel programlaması

  • JAX'in fonksiyonel API'si
    • JAX fonksiyonları saf (pure) olmalıdır; yani global yan etkileri olmamalıdır
    • Matematik fonksiyonları gibi, aynı veri verildiğinde yürütme bağlamından bağımsız olarak her zaman aynı çıktıyı üretmelidir
    • Bu tasarım felsefesi sayesinde JAX fonksiyonları birleştirilebilir ve birbiriyle iyi çalışır
    • Geliştirme karmaşıklığı azalır; fonksiyonlar belirli imzalar ve iyi tanımlanmış somut işler olarak tanımlanır
    • Türler uyduğu sürece fonksiyonun hemen çalışacağı garanti edilir
    • Bu, bilimsel hesaplama ve özellikle derin öğrenmede gereken iş türlerine uygundur
  • optax API örneği
    • Fonksiyonel yaklaşım sayesinde optax içinde "chain" denilen bir yapı vardır
    • Bu, gradient'lere sırasıyla uygulanan birden fazla fonksiyonu içerir
    • Temel yapı taşı GradientTransformation'dır
    • Güçlü ama aynı zamanda ifade gücü yüksek bir API ortaya çıkarır
    • Örneğin gradient clipping yapmak, gradient'in EMA'sını almak veya optimizer'ları birleştirmek son derece kolay hale gelir
  • Fonksiyonel tasarımın avantajları
    • Fonksiyonel tasarımın bir diğer harika sonucu da vmap'tir
    • Bu, 'vectorized' map anlamına gelir ve tam olarak ne yaptığını anlatır
    • Her şeyi map edebilirsiniz ve vmap olduğu sürece XLA bunu otomatik olarak fuse edip optimize eder
    • Fonksiyon yazarken batch boyutunu düşünmeniz gerekmez
    • Tüm kodu sadece vmap etmeniz yeterlidir
    • Bu da daha az ein-* işlemi gerektiği anlamına gelir
    • 2D/3D tensör manipülasyonlarını anlamak daha sezgisel olur ve okunabilirlik de çok daha iyidir
    • Yalnızca tek tek bileşenleri izole edip akıl yürütmeniz gerektiğinden, iyi çalışan karmaşık kodları daha kolay yazabilirsiniz
    • Saflık kısıtlarına uyup doğru imzalara sahip olduğunuz sürece composability gibi diğer tüm avantajlardan yararlanabilirsiniz
  • PyTorch ekosisteminin sorunları
    • torch tarafında, kullandığınız yığın ne olursa olsun (FSDP + çoklu düğüm + torch.compile gibi) her zaman bir şeylerin kırılma ihtimali vardır
    • Birçok parçanın birlikte doğru çalışması gerekir ve tek bir bileşen bile başarısız olursa sabah 3'e kadar hata ayıklamak zorunda kalırsınız
    • PyTorch'un sunduğu onlarca özelliğin tüm kombinasyonlarını test etmek mümkün olmadığından, geliştirme sırasında yakalanmamış hatalar her zaman olacaktır
    • Kayda değer çaba olmadan düzgün çalışan kod yazmak imkânsızdır
    • torch ekosistemi aşırı şişmiş ve hata dolu hale geldi
    • Ortak bir soyutlama olmadığı için, diğer "çözümlerle" arayüz kurmak üzere tasarlanmamış yeni kütüphane ve çerçeveler ortaya çıkıyor
    • Bu da hızla bağımlılık ve requirements.txt karmaşasına dönüşüyor
    • GitHub issue'larının ya da forum tartışmalarının %70-80'i, yalnızca farklı kütüphanelerde hata çıkmasından ibaret
    • Bunu çözmenin neredeyse bir yolu yok
  • Çözüm eksikliği
    • Bu bir OOP ve tasarım sorunudur
    • PyTree gibi temel ve PyTorch tarzı bir nesnenin, soyutlama için ortak bir zemin kurmaya yardımcı olabileceğini düşünüyorum
    • Fonksiyonel programlama paradigması da benimsenemez
    • Çünkü bu, tüm mevcut torch kod tabanlarının geriye dönük uyumluluğunu bozarken JAX'in daha kötü performans veren bir sürümüne yakınsamaya yol açar
    • PyTorch bu konuda tamamen bozulmuş görünüyor

JAX'in yeniden üretilebilirlik üstünlüğü

  • Seed yönetimi
    • PyTorch'un seed yönetimi ideal değil
    • Genellikle birden fazla satır kod çalıştırmanız gerekir
    • Bunu unutmak ya da yanlış yapılandırmak kolaydır
    • JAX, açık anahtarlar oluşturup bunları rastgeleliğe ihtiyaç duyan her fonksiyona geçirmenizi zorunlu kılar
    • Bu yaklaşım, RNG her zaman statik olarak seed'lendiği için sorunu tamamen ortadan kaldırır
    • JAX'in kendi NumPy sürümü (jax.numpy) olduğundan ayrıca seed ayarlamanız gerekmez
    • Bu tür küçük QoL kararları, tüm çerçevenin kullanıcı deneyimini çok daha iyi hale getirebilir
  • Taşınabilirlik
    • PyTorch kod tabanlarıyla çalışırken en büyük sorunlardan biri taşınabilirlik eksikliğidir
    • CUDA/GPU için yazılmış bir kod tabanı, TPU, NPU, AMD GPU gibi Nvidia dışı donanımlarda çalıştırıldığında genellikle iyi çalışmaz
    • Tek düğüm için yazılmış PyTorch kodunu çok düğümlü yapıya taşımak zordur
    • Çok düğümlü yapı çoğu zaman onlarca saat geliştirme süresi ve ciddi kod değişiklikleri gerektirir
    • JAX'in derleyici merkezli yaklaşımı burada avantaj sağlar
    • XLA, cihaz backend'leri arasında geçişi yönetir ve minimum kod değişikliğiyle GPU/TPU/çoklu düğüm/çoklu slice üzerinde iyi çalışır
    • Donanım üreticilerinin cihaz desteği eklemesini ve cihazlar arasında geçişi kolaylaştırır
    • Herkes aynı donanıma erişemediğinden, farklı donanım türlerinde taşınabilir kod tabanları derin öğrenmeyi yeni başlayanlar ve orta seviye kullanıcılar için daha erişilebilir hale getiren küçük bir adım olabilir
  • Otomatik ölçekleme
    • Kendi kendine iyi otomatik ölçeklenebilen bir kod tabanı, yeniden üretim açısından çok faydalıdır
    • İdeal durumda bunun, minimum kod değişikliğiyle ağ sınırlarından bağımsız biçimde otomatik gerçekleşmesi gerekir
    • JAX bunu iyi yapar
    • JAX kodu yazarken iletişim primitive'lerini belirtmenize veya her yere torch.distributed.barrier() koymanıza gerek yoktur
    • XLA, mevcut donanımı dikkate alarak bunları otomatik ekler
    • JAX'in algılayabildiği tüm cihazlar, ağ yapısı, topoloji, yapılandırma vb. fark etmeksizin otomatik olarak kullanılır
    • Hesaplamayı otomatik olarak senkronize edip hazırlar ve kernel'lerin asenkron yürütmesini en üst düzeye çıkarıp gecikmeyi en aza indirmek için optimizasyon geçişleri uygular
    • İnsan olarak yapmanız gereken tek şey, cihazlar arasında dağıtmak istediğiniz tensörün sharding'ini belirtmektir; örneğin giriş dizisinin batch boyutu gibi
    • XLA'nın "hesaplama sharding'i takip eder" yaklaşımı sayesinde gerisini otomatik olarak çözer
    • Ölçekte doğrulanmış deneyleri hobi amaçlı bile kolayca çalıştırıp deneyebilir ve potansiyel olarak yineleyebilirsiniz
    • Bu, unutulmuş fikirleri keşfetmeyi kolaylaştırabilir ve çok daha büyük ölçekte minimum çabayla fonksiyon olarak test etmeyi mümkün kıldığı için bu tür deneyleri teşvik edebilir

JAX'in dezavantajları

  • Yönetişim yapısı
    • Şu anda XLA, TensorFlow yönetişimi altında
    • PyTorch'a benzer ayrı bir kurumsal yapı oluşturulması konuşuldu, ancak somut anlamda pek ilerleme olmadı
    • Google'ın popüler olmayan ürünleri kapatma konusundaki itibarı nedeniyle Google'a güven düzeyi yüksek değil
    • JAX teknik olarak bir DeepMind projesi ve Google'ın genel yapay zeka hamlesi için kritik öneme sahip, ancak tüm ekosistem açısından uzun vadede büyük fayda sağlayacağı da görülüyor
    • Ayrı bir yönetişim organı, proje geliştirmesine rehberlik edebilir
    • Bu, daha somut bir yapı sunar ve Google'ın meşhur bürokrasisinden ayrışarak birçok sorunu bir anda önleyebilir
    • JAX'in illa bu tür resmî bir yapıya ihtiyacı olmayabilir, ancak Google üst yönetiminin kararlarından bağımsız olarak JAX geliştirmesinin uzun yıllar süreceğine dair bir güvence olması iyi olurdu
    • Bu, bir gün bakımsız kalabilecek araçları entegre etmeye kaynak ayırmakta tereddüt eden şirketler ve büyük araştırma laboratuvarları tarafından benimsenmesine açıkça yardımcı olurdu
  • XLA'nın açık kaynağa geçişi
    • Uzun süre boyunca XLA kapalı kaynak bir projeydi
    • Ancak bunu açık kaynak yapmak için çaba gösterildi ve bugün OpenXLA, dahili XLA derlemelerinden çok daha iyi performans gösteriyor
    • Buna rağmen XLA'nın iç yapısına dair dokümantasyon hâlâ yetersiz
    • Kaynakların çoğu canlı konuşmalar ve ara sıra yayımlanan makalelerden ibaret; bunlar da çoğu zaman güncel değil
    • Planlanan özelliklere dair herkese açık bir yol haritası olsa, insanlar gelişimi takip edebilir ve özellikle ilgilerini çeken alanlara katkı sunmaları kolaylaşırdı
    • XLA derleyici yığınının her aşamasını inceleyip ayrıntılarını anlatan, Edward Yang tarzı mini blog yazıları; uygulayıcıların XLA'nın ne yapabildiğini ve ne yapamadığını daha iyi değerlendirmesine yardımcı olabilir
    • Bunun kaynak yoğun bir iş olduğunu ve bu kaynağın başka yere ayrılmasının daha iyi görülebileceğini anlıyorum; ancak insanlar araçları anladıkça onlara daha çok güvenir ve bunun tüm ekosistemde herkese yarar sağlayan olumlu bir dalga etkisi yaratacağını düşünüyorum
  • Ekosistem entegrasyonu
    • flax, JAX ekosisteminin baş belasıdır
    • Sezgisel olmayan API'si, özlü sözdizimi vardır ve PyTorch'tan geçen yeni başlayanlar için tam bir kâbustur
    • equinox kullanmak daha iyidir
    • Geliştirme ekibinin flaxın sorunlarını çözmeye yönelik girişimleri oldu ama sonuçta bu zaman kaybıdır
    • equinox tarzı bir API istiyorsanız doğrudan equinox kullanmak daha iyidir
    • flaxın belirgin biçimde daha iyi yaptığı çok az şey vardır ve bunları equinox ile kopyalamak zor değildir
    • Şu anda JAX ekosisteminin büyük kısmı flax etrafında tasarlanmış durumda
    • equinox temelde PyTree ile arayüz kurduğu için tüm kütüphanelerle uyumludur, ancak biraz eqx.partition ve filter gerekir
    • Mevcut durumun değişmesini istiyorum. equinox her yerde birinci sınıf desteğe sahip olmalı
    • Bu tartışmalı bir görüş olabilir ama bu klasik bir sunk cost yanılgısıdır
    • equinox, JAX çerçevesinin en başından beri olması gerektiği şekilde daha iyi çalışıyor
    • equinox dokümantasyonunda özetlendiği gibi equinox ile flax karşılaştırıldığında equinox daha iyi
    • JAX ekosistemi yöneticilerinin equinoxun popülerliğini fark edip buna göre hareket etmesi sevindirici, ancak Google ve flax ekibinden de resmî olarak daha fazla destek görmek isterim
    • JAX'i denemek istiyorsanız equinox kullanmanız daha iyi olur
  • Sivri köşeler
    • API tasarım kararları ve XLA kısıtları nedeniyle JAX'te dikkat edilmesi gereken bazı "sivri köşeler" vardır
    • İyi yazılmış belgelerde bunlar çok özlü biçimde açıklanmıştır
    • JAX kullanmadan önce bunları en az bir kez okumanız tavsiye edilir
    • Her zaman olduğu gibi RTFM yapmak size çok zaman ve enerji kazandıracaktır

Sonuç

  • Bu blog yazısı, PyTorch'un gerçek araştırma iş yükleri için, özellikle de GPU tarafında, en uygun seçenek olduğuna dair sıkça tekrarlanan miti düzeltmek için yazıldı. Artık öyle değil
  • Hatta tüm PyTorch kodunu JAX'e taşımak, alanın tamamı için muazzam derecede faydalı olur diyecek kadar ileri gidiyorum
    • Otomatik paralelleştirme, yeniden üretilebilirlik, temiz fonksiyonel API gibi özellikler önemsiz ayrıntılar değil; pek çok araştırma kod tabanı için büyük fayda sağlayabilir
  • Bu alanı az da olsa daha iyi hale getirmek istiyorsanız, kod tabanınızı JAX ile yeniden yazmayı düşünün

8 yorum

 
xguru 2024-08-25

Dünya akıp gitmeye devam ediyor. haha

2022'de PyTorch ve TensorFlow karşılaştırması

 
hilft 2024-08-21

torch ve onnx ile idare edeceğim

 
flrngel 2024-08-21

Lisans öğrencisinin yazdığı bir yazı... Vay be.

 
cosine20 2024-08-21

PyTorch, Huggingface olmasaydı gerçekten bitmiştik lol

 
lemonmint 2024-08-19

Yaşasın JAX! Kısa süre önce denedim ve NNX API'sini gerçekten çok beğendim.

 
stareta1202 2024-08-19

JAX’in en büyük sorunu Google çıkışlı olması. Google’ın açık kaynak projeleri bırakmasıyla epey ünlü olduğu biliniyor (Tflite, android things, dart, angular, bazel vb.); tensorflow da bir noktadan sonra düzgün güncelleme almamaya başladı. Buna karşılık torch, devasa açık kaynak projeleri yürüten Facebook’ta başladı, oldukça iyi yönetildi ve zaten şu anda Torch Foundation tarafından işletiliyor. torch için söylenen dezavantajların haklı olduğu noktalar elbette var, ancak bu açık kaynağı kimin sürdürülebilir biçimde işleteceği açısından bakınca JAX sanki en baştan büyük bir riskle yola çıkıyor gibi görünüyor.

 
dalinaum 2024-08-20

En azından Dart, Flutter sayesinde bir süre daha iyi yaşayacak gibi görünüyor.

 
ilotoki0804 2024-08-20

Facebook, React, Django gibi en azından kendi kullandıkları teknoloji yığınına karşı bir şekilde sadık kalıp(?) sürekli katkı yapıyor gibi görünüyor; ama Google biraz demode olur olmaz paçavra gibi bir kenara atıyor sanki...