Это не официальный сайт wikipedia.org 01.01.2023

Google JAX — Википедия

Google JAX — фреймворк машинного обучения для преобразования числовых функций.[1][2][3] Он описывается как объединение измененной версии autograd (автоматическое получение градиентной функции через дифференцирование функции) и TensorFlow's XLA (Ускоренная линейная алгебра(Accelerated Linear Algebra)). Он спроектирован, чтобы максимально соответствовать структуре и рабочему процессу NumPy для работы с различными существующими фреймворками, такими как TensorFlow и PyTorch.[4][5] Основными функциями JAX являются:[1]

  1. grad: автоматическое дифференцирование
  2. jit: компиляция
  3. vmap: автоматическая векторизация
  4. pmap: SPMD программирование
Google JAX
Логотип программы Google JAX
Скриншот программы Google JAX
Тип Machine learning
Разработчик Google (компания)
Написана на Python, C++
Операционная система Linux, macOS, Windows
Аппаратная платформа Python, NumPy
Тестовая версия v0.3.13 (16 мая 2022; 9 месяцев назад (2022-05-16))
Лицензия Apache 2.0
Сайт jax.readthedocs.io/en/la…

gradПравить

Код представленный ниже демонстрирует функцию автоматического дифференцирования пакета grad.

# imports
from jax import grad
import jax.numpy as jnp

# define the logistic function
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)

# evaluate the gradient of the logistic function at x = 1 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

Код должен напечатать:

0.19661194

jitПравить

Код представленный ниже демонстрирует функцию оптимизации через слияние пакета jit.

# imports
from jax import jit
import jax.numpy as jnp

# define the cube function
def cube(x):
    return x * x * x

# generate data
x = jnp.ones((10000, 10000))

# create the jit version of the cube function
jit_cube = jit(cube)

# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)

Вычислительное время для jit_cube (строка 17) должно быть заметно короче, чем для cube (строка 16). Увеличение значения в строке 7, будет увеличивать разницу.

vmapПравить

Код представленный ниже демонстрирует функцию векторизации пакета vmap.

# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp

# define function
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

Изображение в правой части раздела иллюстрирует идея векторизованного сложения.

 
Иллюстрационное видео векторизованного сложения

pmapПравить

Код представленный ниже демонстрирует распараллеливание для умножения матриц пакета pmap.

# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)

Последняя строка должна напечатать значенияː

[1.1566595 1.1805978]

Библиотеки использующие JaxПравить

Несколько библиотек Python используют Jax в качестве бэкенда, включая:

  • Flax, высокоуровневая библиотека для нейронных сетей изначально разработанная Google Brain.[6]
  • Haiku, объектно-ориентированная библиотека для нейронных сетей разработанная DeepMind.[7]
  • Equinox, библиотека которая вращается вокруг идеи представления параметризованных функций (включая нейронные сети) как PyTrees. Она была создана Патриком Кидгером.[8]
  • Optax, библиотека для градиентной обработки и оптимизации разработанная DeepMind.[9]
  • RLax, библиотека для разработки агентов для обучения с подкреплением, разработанная DeepMind.[10]

См. такжеПравить

ПримечанияПравить

  1. 1 2 Bradbury, James; Frostig, Roy; Hawkins, Peter & Johnson, Matthew James (2022-06-18), JAX: Autograd and XLA, Astrophysics Source Code Library (Google), <https://github.com/google/jax>. Проверено 18 июня 2022. 
  2. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). “Compiling machine learning programs via high-level tracing” (PDF). MLsys: 1—3. Архивировано из оригинала (PDF) 2022-06-21. Используется устаревший параметр |url-status= (справка)
  3. Using JAX to accelerate our research (англ.). www.deepmind.com. Дата обращения: 18 июня 2022. Архивировано 18 июня 2022 года.
  4. Lynley, Matthew Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta (англ.). Business Insider. Дата обращения: 21 июня 2022. Архивировано 21 июня 2022 года.
  5. Why is Google's JAX so popular? (англ.). Analytics India Magazine (25 апреля 2022). Дата обращения: 18 июня 2022. Архивировано 18 июня 2022 года.
  6. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, <https://github.com/google/flax>. Проверено 29 июля 2022. 
  7. Haiku: Sonnet for JAX, DeepMind, 2022-07-29, <https://github.com/deepmind/dm-haiku>. Проверено 29 июля 2022. 
  8. Kidger, Patrick (2022-07-29), Equinox, <https://github.com/patrick-kidger/equinox>. Проверено 29 июля 2022. 
  9. Optax, DeepMind, 2022-07-28, <https://github.com/deepmind/optax>. Проверено 29 июля 2022. 
  10. RLax, DeepMind, 2022-07-29, <https://github.com/deepmind/rlax>. Проверено 29 июля 2022. 

СсылкиПравить