
jax-skills

High-performance numerical computing and machine learning workflows using JAX. Supports array operations, automatic differentiation, JIT compilation, RNN-sty...
Author: wu-uk (source)
Uncategorized · clawhub v0.1.0 · 1 version · 100000 · Key: not required
★ 0 Stars · 📥 357 Downloads · 💾 0 Installs · 1 Version · #latest

Overview

Requirements for Outputs

General Guidelines

Arrays

  • All arrays MUST be compatible with JAX (jnp.array) or convertible from Python lists.
  • Use .npy, .npz, JSON, or pickle for saving arrays.

Operations

  • Validate input types and shapes for all functions.
  • Maintain numerical stability in all operations (for example, use jax.nn.softplus or jnp.logaddexp instead of computing log(1 + exp(x)) literally).
  • Provide meaningful error messages for unsupported operations or invalid inputs.
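The validation and stability guidelines above can be illustrated with a short sketch. `check_array` is a hypothetical helper, not part of jax_skills; the second half compares a literal log(1 + exp(z)) with `jax.nn.softplus`, which computes the same function without overflowing in float32.

```python
import jax
import jax.numpy as jnp


def check_array(x, name, ndim=None):
    """Hypothetical helper: convert to a JAX array and validate its rank."""
    arr = jnp.asarray(x)
    if ndim is not None and arr.ndim != ndim:
        # Name the argument and report the actual shape so the error is actionable.
        raise ValueError(f"{name} must have {ndim} dimension(s), got shape {arr.shape}")
    return arr


z = jnp.array([-1000.0, 0.0, 1000.0])
naive = jnp.log1p(jnp.exp(z))  # exp(1000) overflows float32, so the last entry is inf
stable = jax.nn.softplus(z)    # same function, evaluated without overflow
```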

JAX Skills

1. Loading and Saving Arrays

load(path)

Description: Load a JAX-compatible array from a file. Supports .npy and .npz.

Parameters:

  • path (str): Path to the input file.

Returns: a JAX array for .npy files, or a dict of arrays for .npz files.

import jax_skills as jx

arr = jx.load("data.npy")
arr_dict = jx.load("data.npz")
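A minimal sketch of how a loader like `load` could be built on NumPy's file readers. This is an assumption about the implementation, not the skill's actual code: each array is converted with `jnp.asarray`, and for `.npz` the result is a dict keyed by the archive's member names. `np.load` keeps its default `allow_pickle=False`.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as np


def load(path):
    """Sketch: load .npy as a JAX array, or .npz as a dict of JAX arrays."""
    ext = os.path.splitext(str(path))[1].lower()
    if ext == ".npy":
        return jnp.asarray(np.load(path))
    if ext == ".npz":
        # NpzFile reads members lazily, so convert them all before the file closes.
        with np.load(path) as data:
            return {name: jnp.asarray(data[name]) for name in data.files}
    raise ValueError(f"unsupported extension {ext!r} for {path!r}; expected .npy or .npz")


# Usage: write sample files to a temporary directory, then load them back.
tmp = tempfile.mkdtemp()
np.save(os.path.join(tmp, "data.npy"), np.arange(3.0))
np.savez(os.path.join(tmp, "data.npz"), x=np.ones(2), y=np.zeros((2, 2)))
arr = load(os.path.join(tmp, "data.npy"))
arr_dict = load(os.path.join(tmp, "data.npz"))
```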

save(data, path)

Description: Save a JAX array or an array-like Python object (for example, a nested list) to a .npy file.

Parameters:

  • data (array): Array to save.
  • path (str): Destination file path.

jx.save(arr, "output.npy")
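A matching sketch for `save`, again an assumption about how it might work rather than the skill's code: `np.asarray` copies the JAX array into a host NumPy array, which `np.save` then writes.

```python
import os
import tempfile

import jax.numpy as jnp
import numpy as np


def save(data, path):
    """Sketch: write a JAX array or array-like object to a .npy file."""
    if not str(path).endswith(".npy"):
        raise ValueError(f"save() writes .npy files; got {path!r}")
    # np.asarray accepts JAX arrays and nested lists alike.
    np.save(path, np.asarray(data))


# Usage: save a small array and read it back with NumPy.
tmp = tempfile.mkdtemp()
out_path = os.path.join(tmp, "output.npy")
save(jnp.arange(4.0) ** 2, out_path)
restored = np.load(out_path)
```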

2. Map and Reduce Operations

map_op(array, op)

Description: Apply an elementwise operation to an array using jax.vmap.

Parameters:

  • array (array): Input array.
  • op (str): Operation name; only "square" is currently supported.

squared = jx.map_op(arr, "square")
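A sketch of `map_op` built on `jax.vmap`, assuming a simple name-to-function registry (`_MAP_OPS` is hypothetical). Because the op is elementwise, mapping it over the leading axis gives the same result as applying it to the whole array; the vmap call mainly shows the batching pattern the description refers to.

```python
import jax
import jax.numpy as jnp

# Hypothetical registry; only "square" is documented as supported.
_MAP_OPS = {"square": jnp.square}


def map_op(array, op):
    if op not in _MAP_OPS:
        raise ValueError(f"unsupported map op {op!r}; supported: {sorted(_MAP_OPS)}")
    x = jnp.asarray(array)
    if x.ndim == 0:
        # vmap needs at least one axis to map over; apply scalars directly.
        return _MAP_OPS[op](x)
    # Map the op over the leading axis (each slice is processed independently).
    return jax.vmap(_MAP_OPS[op])(x)
```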

reduce_op(array, op, axis)

Description: Reduce an array along a given axis.

Parameters:

  • array (array): Input array.
  • op (str): Operation name; only "mean" is currently supported.
  • axis (int): Axis along which to reduce.

mean_vals = jx.reduce_op(arr, "mean", axis=0)
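A sketch of `reduce_op` with explicit axis validation, assuming a hypothetical `_REDUCE_OPS` registry in which only "mean" is wired up, matching the documented support. Negative axes follow the usual NumPy convention.

```python
import jax.numpy as jnp

# Hypothetical registry; only "mean" is documented as supported.
_REDUCE_OPS = {"mean": jnp.mean}


def reduce_op(array, op, axis):
    if op not in _REDUCE_OPS:
        raise ValueError(f"unsupported reduce op {op!r}; supported: {sorted(_REDUCE_OPS)}")
    x = jnp.asarray(array)
    # Valid axes are -ndim .. ndim-1; a 0-d array has no axis to reduce.
    if not -x.ndim <= axis < x.ndim:
        raise ValueError(f"axis {axis} is out of range for an array with ndim {x.ndim}")
    return _REDUCE_OPS[op](x, axis=axis)
```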

3. Gradients and Optimization

logistic_grad(x, y, w)

Description: Compute the gradient of logistic loss with respect to weights.

Parameters:

  • x (array): Input features.
  • y (array): Labels.
  • w (array): Weight vector.

grad_w = jx.logistic_grad(X_train, y_train, w_init)

Notes:

  • Uses jax.grad for automatic differentiation.
  • Logistic loss: mean(log(1 + exp(-y * (x @ w)))), with labels y in {-1, +1}.
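A sketch of `logistic_grad` built on `jax.grad`, assuming labels in {-1, +1} and shapes x (n, d), y (n,), w (d,). The loss is written with `jax.nn.softplus` for numerical stability. `logistic_grad_closed_form` is a hypothetical cross-check using the analytic gradient -(1/n) Σ y_i x_i σ(-y_i x_i·w).

```python
import jax
import jax.numpy as jnp


def logistic_loss(w, x, y):
    # mean(log(1 + exp(-y * (x @ w)))); softplus(t) = log(1 + exp(t)) without overflow.
    margins = y * (x @ w)
    return jnp.mean(jax.nn.softplus(-margins))


def logistic_grad(x, y, w):
    x, y, w = (jnp.asarray(a, dtype=jnp.float32) for a in (x, y, w))
    if x.ndim != 2 or y.shape != (x.shape[0],) or w.shape != (x.shape[1],):
        raise ValueError(
            f"expected x (n, d), y (n,), w (d,); got {x.shape}, {y.shape}, {w.shape}"
        )
    # jax.grad differentiates with respect to the first argument, w.
    return jax.grad(logistic_loss)(w, x, y)


def logistic_grad_closed_form(x, y, w):
    # Analytic gradient: -(1/n) * sum_i y_i * x_i * sigmoid(-y_i * (x_i @ w)).
    margins = y * (x @ w)
    return -(x.T @ (y * jax.nn.sigmoid(-margins))) / x.shape[0]
```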

4. Recurrent Scan

rnn_scan(seq, Wx, Wh, b)

Description: Apply an RNN-style scan over a sequence using JAX lax.scan.

Parameters:

  • seq (array): Input sequence.
  • Wx (array): Input-to-hidden weight matrix.
  • Wh (array): Hidden-to-hidden weight matrix.
  • b (array): Bias vector.

hseq = jx.rnn_scan(sequence, Wx, Wh, b)

Notes:

  • Returns sequence of hidden states.
  • Uses tanh activation.
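A sketch of `rnn_scan` with `jax.lax.scan`. The shape conventions (seq (T, d_in), Wx (d_in, d_h), Wh (d_h, d_h), b (d_h,)), the row-vector update h_t = tanh(x_t @ Wx + h_{t-1} @ Wh + b), and the zero initial state are assumptions; the documentation above does not specify them.

```python
import jax
import jax.numpy as jnp


def rnn_scan(seq, Wx, Wh, b):
    # Assumed conventions: seq (T, d_in), Wx (d_in, d_h), Wh (d_h, d_h), b (d_h,).
    seq, Wx, Wh, b = (jnp.asarray(a, dtype=jnp.float32) for a in (seq, Wx, Wh, b))
    d_h = Wh.shape[0] if Wh.ndim == 2 else -1
    if (seq.ndim != 2 or Wx.shape != (seq.shape[1], d_h)
            or Wh.shape != (d_h, d_h) or b.shape != (d_h,)):
        raise ValueError(
            f"incompatible shapes: seq {seq.shape}, Wx {Wx.shape}, Wh {Wh.shape}, b {b.shape}"
        )

    def step(h_prev, x_t):
        # One recurrence step with tanh activation.
        h_t = jnp.tanh(x_t @ Wx + h_prev @ Wh + b)
        return h_t, h_t  # (carry for the next step, output collected by scan)

    h0 = jnp.zeros(d_h, dtype=jnp.float32)  # assumed zero initial hidden state
    _, hs = jax.lax.scan(step, h0, seq)
    return hs  # hidden state for every time step, shape (T, d_h)
```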

5. JIT Compilation

jit_run(fn, args)

Description: JIT compile and run a function using JAX.

Parameters:

  • fn (callable): Function to compile.
  • args (tuple): Arguments for the function.

result = jx.jit_run(my_function, (arg1, arg2))

Notes:

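  • Speeds up repeated calls; the first call pays a one-time tracing and compilation cost.
  • Keep input shapes and dtypes consistent across calls: each new shape/dtype combination triggers a fresh trace and compilation.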
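A sketch of `jit_run`, assuming it wraps `jax.jit`. The hypothetical `_jitted` cache keeps one jit wrapper per Python function so repeated calls can reuse compiled code, and the example function `sum_of_squares` records every shape it is traced with, which makes the retracing behavior described above visible.

```python
import functools

import jax
import jax.numpy as jnp


@functools.lru_cache(maxsize=None)
def _jitted(fn):
    # Hypothetical cache: one jax.jit wrapper per Python function.
    return jax.jit(fn)


def jit_run(fn, args):
    if not callable(fn):
        raise TypeError(f"fn must be callable, got {type(fn).__name__}")
    if not isinstance(args, tuple):
        raise TypeError(f"args must be a tuple, got {type(args).__name__}")
    return _jitted(fn)(*args)


traces = []


def sum_of_squares(x):
    # Python side effect: runs only while JAX traces, not on cached calls.
    traces.append(x.shape)
    return jnp.sum(x ** 2)


first = jit_run(sum_of_squares, (jnp.ones(4),))          # traced for shape (4,)
second = jit_run(sum_of_squares, (2.0 * jnp.ones(4),))   # same shape/dtype: reuses the trace
third = jit_run(sum_of_squares, (jnp.ones(5),))          # new shape: traced again
```

After these three calls, `traces` holds one entry per distinct input shape, (4,) and (5,), even though the function was called three times.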

Best Practices

  • Prefer JAX arrays (jnp.array) for all operations; convert to NumPy only when saving.
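  • Avoid side effects inside functions passed to jit, vmap, or scan; these transformations trace the function, so Python side effects run at trace time rather than on every call.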
  • Validate input shapes for map_op, reduce_op, and rnn_scan.
  • Use JIT compilation (jit_run) for compute-heavy functions.
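  • Save arrays as .npy or .npz for portability; JSON requires converting arrays to nested lists, and pickle files should only be loaded from trusted sources because unpickling can execute arbitrary code.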

Example Workflow

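import jax
import jax.numpy as jnp
import jax_skills as jx

# Load array
arr = jx.load("data.npy")

# Square elements
arr2 = jx.map_op(arr, "square")

# Reduce along axis 0
mean_arr = jx.reduce_op(arr2, "mean", axis=0)

# Illustrative random inputs; shapes are placeholders (replace with real data).
# The Wx/Wh layout assumes the x_t @ Wx convention for rnn_scan.
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
X_train = jax.random.normal(k1, (100, 3))
y_train = jnp.where(jax.random.bernoulli(k2, 0.5, (100,)), 1.0, -1.0)  # labels in {-1, +1}
w_init = jnp.zeros(3)
sequence = jax.random.normal(k3, (20, 3))  # (time steps, input dim)
Wx = 0.1 * jax.random.normal(k4, (3, 8))
Wh = 0.1 * jax.random.normal(k5, (8, 8))
b = jnp.zeros(8)

# Compute logistic gradient
grad_w = jx.logistic_grad(X_train, y_train, w_init)

# RNN scan
hseq = jx.rnn_scan(sequence, Wx, Wh, b)

# Save result
jx.save(hseq, "hseq.npy")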

Notes

  • This skill set is designed for scientific computing, ML model prototyping, and dynamic array transformations.
  • Emphasizes JAX-native operations, automatic differentiation, and JIT compilation.
  • Avoid unnecessary conversions to NumPy; only convert when interacting with external file formats.

Version History

1 version in total

  • v0.1.0 (current)
    2026-05-07 18:55 · Security: safe

Security Scans

Tencent Cloud Security (Keen)

Safe, no risk

Tencent Cloud Security (Sanbu)

Safe, no risk

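🔗 Related

xlsx

wu-uk
Comprehensive spreadsheet creation, editing, and analysis, with support for formulas, formatting, data analysis, and visualization. When Claude needs to work...
★ 0 📥 521

xlsx

wu-uk
Comprehensive spreadsheet creation, editing, and analysis, with support for formulas, formatting, data analysis, and visualization. When Claude needs to work...
★ 1 📥 999
office-efficiency

modora

wu-uk
Use this skill to analyze PDFs through a remote MoDora HTTP service; credentials are managed through declared environment variables and are not stored on the server.
★ 0 📥 503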