diff --git a/README.md b/README.md index 9602dfc..f508752 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ > (解决网络, 文件, res, assets等图片的获取, 解析, 展示, 缓存等需求...) 名称 | 概要 | 详情 ---- | --- | --- | --- +--- | --- | --- [*Picasso](https://github.com/square/picasso) | Github大神推荐的强大的图片下载和缓存库 | Square 开源的项目,主导者是 [JakeWharton](https://github.com/JakeWharton). [*Glide](https://github.com/bumptech/glide) | Google推荐的图片加载和缓存的库 | 专注于平滑滚动时的流畅加载, Google开源项目, 2014年Google I/O 上被推荐 [*Fresco](https://github.com/facebook/fresco) | Facebook推荐的的Android图片加载库 | 自动管理图片的加载和图片的缓存.Facebook 在2015年上半年开源的图片加载库 @@ -55,7 +55,7 @@ > (解决图片缩放, 裁剪, 平移, 旋转等需求) 名称 | 概要 | 详情 ---- | --- | --- | --- +--- | --- | --- [PinchImageView](https://github.com/boycy815/PinchImageView)|国人写的, 可能是体验最好的图片手势控件| 支持双击放大,双击缩小,超出边界会回弹, 滑动惯性,不同分辨率无缝切换,可与ViewPager结合使用。 star:360 [GestureViews](https://github.com/alexvasilkov/GestureViews)|包含ImageView的自定义FrameLayout | 项目目的是让图片的查看尽可能流畅平滑, 让开发者更加方便地集成到自己的应用中, 支持手势控制和动画 star:582 [*PhotoView](https://github.com/chrisbanes/PhotoView) | 致力于帮助开发者高效的创建可缩放的ImageView | 重写ImageView的实现, 支持多点触摸的图片缩放 star:4705 @@ -85,7 +85,7 @@ > 解决各种协议(GET, POST, PUT, HEAD, DETELE...)的网络数据的获取及请求, 支持异步,同步请求; 文件多线程下载断点续传, 上传; 请求自动重试, gzip压缩, Cookies自动解析并持久化. 数据的缓存. 目标是让网络请求更方便, 简介, 高效, 稳定. 名称 | 概要 | 详情 ---- | --- | --- | --- +--- | --- | --- [*Retrofit2.0](http://square.github.io/retrofit/) | 以接口/注解的形式定义请求和响应 | Square 开源的项目. 是一套RESTful架构的Android(Java)客户端实现,基于注解,提供JSON to POJO(Plain Ordinary Java Object,简单Java对象),POJO to JSON,网络请求(POST,GET,PUT,DELETE等)封装。 Jake Wharton大神力荐. 本身的网络核心可以替换. 如Apache HTTP client, URL connection, OKHttp等, 数据解析核心也可以替换如Gson, Jackson, fastjson, xStream等. 力求用最少的代码, 实现最强大的功能. [官方主页](http://square.github.io/retrofit/) [*okhttp](https://github.com/square/okhttp) | 一个为安卓和java应用诞生的Http+SPDY的网络处理库 | square开源项目. a. 支持HTTP, HTTPS, HTTP/2.0, and SPDY协议 b. 自动缓存数据, 节省流量, c.内部自动GZIP压缩内容. [android-async-http](https://github.com/loopj/android-async-http) | 一个异步的AndroidHttp库 | 比较经典的网络请求库, 基于Apache的HttpClient库实现, 但是由于AndroidM(6.0)去除了对HttpClient相关API, 意味着google不再推荐使用. diff --git a/train.ipynb b/train.ipynb new file mode 100644 index 0000000..4edb923 --- /dev/null +++ b/train.ipynb @@ -0,0 +1,934 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "0", + "metadata": { + "id": "0" + }, + "source": [ + "# K-Scale Humanoid Benchmark\n", + "\n", + "Welcome to the K-Scale Humanoid Benchmark! This notebook will walk you through training your own reinforcement learning policy, which you can then use to control a K-Scale robot.\n", + "\n", + "*Note:* The Just-In-Time compilation may take a while and cause your Colab instance to appear to disconnect. However, your training cell may actually still be running. Make sure to check before restarting!\n", + "\n", + "*Last updated: 2025/05/15*" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": { + "id": "1" + }, + "source": [ + "## Dependencies and Config\n", + "\n", + "The K-Scale Humanoid Benchmark uses K-Scale's open-source RL framework [K-Sim](https://github.com/kscalelabs/ksim) for training and the [K-Scale API](https://github.com/kscalelabs/kscale) for asset management." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": { + "id": "2" + }, + "outputs": [], + "source": [ + "# Install packages\n", + "\n", + "!pip install ksim==0.1.2 xax==0.3.0 mujoco-scenes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": { + "id": "3" + }, + "outputs": [], + "source": [ + "# Set up environment variables\n", + "%env MUJOCO_GL=egl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": { + "id": "4" + }, + "outputs": [], + "source": [ + "import asyncio\n", + "import functools\n", + "import math\n", + "from dataclasses import dataclass\n", + "from typing import Self\n", + "\n", + "import attrs\n", + "import distrax\n", + "import equinox as eqx\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import ksim\n", + "import mujoco\n", + "import mujoco_scenes\n", + "import mujoco_scenes.mjcf\n", + "import nest_asyncio\n", + "import optax\n", + "import xax\n", + "from jaxtyping import Array, PRNGKeyArray\n", + "\n", + "nest_asyncio.apply()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": { + "id": "5" + }, + "outputs": [], + "source": [ + "# These are in the order of the neural network outputs.\n", + "ZEROS: list[tuple[str, float]] = [\n", + " (\"dof_right_shoulder_pitch_03\", 0.0),\n", + " (\"dof_right_shoulder_roll_03\", math.radians(-10.0)),\n", + " (\"dof_right_shoulder_yaw_02\", 0.0),\n", + " (\"dof_right_elbow_02\", math.radians(90.0)),\n", + " (\"dof_right_wrist_00\", 0.0),\n", + " (\"dof_left_shoulder_pitch_03\", 0.0),\n", + " (\"dof_left_shoulder_roll_03\", math.radians(10.0)),\n", + " (\"dof_left_shoulder_yaw_02\", 0.0),\n", + " (\"dof_left_elbow_02\", math.radians(-90.0)),\n", + " (\"dof_left_wrist_00\", 0.0),\n", + " (\"dof_right_hip_pitch_04\", math.radians(-20.0)),\n", + " (\"dof_right_hip_roll_03\", math.radians(-0.0)),\n", + " (\"dof_right_hip_yaw_03\", 0.0),\n", + " (\"dof_right_knee_04\", math.radians(-50.0)),\n", + " (\"dof_right_ankle_02\", math.radians(30.0)),\n", + " (\"dof_left_hip_pitch_04\", math.radians(20.0)),\n", + " (\"dof_left_hip_roll_03\", math.radians(0.0)),\n", + " (\"dof_left_hip_yaw_03\", 0.0),\n", + " (\"dof_left_knee_04\", math.radians(50.0)),\n", + " (\"dof_left_ankle_02\", math.radians(-30.0)),\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": { + "id": "6" + }, + "source": [ + "## Rewards\n", + "\n", + "When training a reinforcement learning agent, the most important thing to define is what reward you want the agent to maximimze. `ksim` includes a number of useful default rewards for training walking agents, but it is often a good idea to define new rewards to encourage specific types of behavior. The cell below shows an example of how to define a custom reward. A similar pattern can be used to define custom objectives, events, observations, and more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": { + "id": "7" + }, + "outputs": [], + "source": [ + "@attrs.define(frozen=True, kw_only=True)\n", + "class JointPositionPenalty(ksim.JointDeviationPenalty):\n", + " @classmethod\n", + " def create_from_names(\n", + " cls,\n", + " names: list[str],\n", + " physics_model: ksim.PhysicsModel,\n", + " scale: float = -1.0,\n", + " scale_by_curriculum: bool = False,\n", + " ) -> Self:\n", + " zeros = {k: v for k, v in ZEROS}\n", + " joint_targets = [zeros[name] for name in names]\n", + "\n", + " return cls.create(\n", + " physics_model=physics_model,\n", + " joint_names=tuple(names),\n", + " joint_targets=tuple(joint_targets),\n", + " scale=scale,\n", + " scale_by_curriculum=scale_by_curriculum,\n", + " )\n", + "\n", + "\n", + "@attrs.define(frozen=True, kw_only=True)\n", + "class BentArmPenalty(JointPositionPenalty):\n", + " @classmethod\n", + " def create_penalty(\n", + " cls,\n", + " physics_model: ksim.PhysicsModel,\n", + " scale: float = -1.0,\n", + " scale_by_curriculum: bool = False,\n", + " ) -> Self:\n", + " return cls.create_from_names(\n", + " names=[\n", + " \"dof_right_shoulder_pitch_03\",\n", + " \"dof_right_shoulder_roll_03\",\n", + " \"dof_right_shoulder_yaw_02\",\n", + " \"dof_right_elbow_02\",\n", + " \"dof_right_wrist_00\",\n", + " \"dof_left_shoulder_pitch_03\",\n", + " \"dof_left_shoulder_roll_03\",\n", + " \"dof_left_shoulder_yaw_02\",\n", + " \"dof_left_elbow_02\",\n", + " \"dof_left_wrist_00\",\n", + " ],\n", + " physics_model=physics_model,\n", + " scale=scale,\n", + " scale_by_curriculum=scale_by_curriculum,\n", + " )\n", + "\n", + "\n", + "@attrs.define(frozen=True, kw_only=True)\n", + "class StraightLegPenalty(JointPositionPenalty):\n", + " @classmethod\n", + " def create_penalty(\n", + " cls,\n", + " physics_model: ksim.PhysicsModel,\n", + " scale: float = -1.0,\n", + " scale_by_curriculum: bool = False,\n", + " ) -> Self:\n", + " return cls.create_from_names(\n", + " names=[\n", + " \"dof_left_hip_roll_03\",\n", + " \"dof_left_hip_yaw_03\",\n", + " \"dof_right_hip_roll_03\",\n", + " \"dof_right_hip_yaw_03\",\n", + " ],\n", + " physics_model=physics_model,\n", + " scale=scale,\n", + " scale_by_curriculum=scale_by_curriculum,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": { + "id": "8" + }, + "source": [ + "## Actor-Critic Model\n", + "\n", + "We train our reinforcement learning agent using an RNN-based actor and critic, which we define below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": { + "id": "9" + }, + "outputs": [], + "source": [ + "class Actor(eqx.Module):\n", + " \"\"\"Actor for the walking task.\"\"\"\n", + "\n", + " input_proj: eqx.nn.Linear\n", + " rnns: tuple[eqx.nn.GRUCell, ...]\n", + " output_proj: eqx.nn.Linear\n", + " num_inputs: int = eqx.static_field()\n", + " num_outputs: int = eqx.static_field()\n", + " num_mixtures: int = eqx.static_field()\n", + " min_std: float = eqx.static_field()\n", + " max_std: float = eqx.static_field()\n", + " var_scale: float = eqx.static_field()\n", + "\n", + " def __init__(\n", + " self,\n", + " key: PRNGKeyArray,\n", + " *,\n", + " num_inputs: int,\n", + " num_outputs: int,\n", + " min_std: float,\n", + " max_std: float,\n", + " var_scale: float,\n", + " hidden_size: int,\n", + " num_mixtures: int,\n", + " depth: int,\n", + " ) -> None:\n", + " # Project input to hidden size\n", + " key, input_proj_key = jax.random.split(key)\n", + " self.input_proj = eqx.nn.Linear(\n", + " in_features=num_inputs,\n", + " out_features=hidden_size,\n", + " key=input_proj_key,\n", + " )\n", + "\n", + " # Create RNN layer\n", + " key, rnn_key = jax.random.split(key)\n", + " rnn_keys = jax.random.split(rnn_key, depth)\n", + " self.rnns = tuple(\n", + " [\n", + " eqx.nn.GRUCell(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " key=rnn_key,\n", + " )\n", + " for rnn_key in rnn_keys\n", + " ]\n", + " )\n", + "\n", + " # Project to output\n", + " self.output_proj = eqx.nn.Linear(\n", + " in_features=hidden_size,\n", + " out_features=num_outputs * 3 * num_mixtures,\n", + " key=key,\n", + " )\n", + "\n", + " self.num_inputs = num_inputs\n", + " self.num_outputs = num_outputs\n", + " self.num_mixtures = num_mixtures\n", + " self.min_std = min_std\n", + " self.max_std = max_std\n", + " self.var_scale = var_scale\n", + "\n", + " def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]:\n", + " x_n = self.input_proj(obs_n)\n", + " out_carries = []\n", + " for i, rnn in enumerate(self.rnns):\n", + " x_n = rnn(x_n, carry[i])\n", + " out_carries.append(x_n)\n", + " out_n = self.output_proj(x_n)\n", + "\n", + " # Reshape the output to be a mixture of gaussians.\n", + " slice_len = self.num_outputs * self.num_mixtures\n", + " mean_nm = out_n[..., :slice_len].reshape(self.num_outputs, self.num_mixtures)\n", + " std_nm = out_n[..., slice_len : slice_len * 2].reshape(self.num_outputs, self.num_mixtures)\n", + " logits_nm = out_n[..., slice_len * 2 :].reshape(self.num_outputs, self.num_mixtures)\n", + "\n", + " # Softplus and clip to ensure positive standard deviations.\n", + " std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std)\n", + "\n", + " # Apply bias to the means.\n", + " mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None]\n", + "\n", + " dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm)\n", + "\n", + " return dist_n, jnp.stack(out_carries, axis=0)\n", + "\n", + "\n", + "class Critic(eqx.Module):\n", + " \"\"\"Critic for the walking task.\"\"\"\n", + "\n", + " input_proj: eqx.nn.Linear\n", + " rnns: tuple[eqx.nn.GRUCell, ...]\n", + " output_proj: eqx.nn.Linear\n", + " num_inputs: int = eqx.static_field()\n", + "\n", + " def __init__(\n", + " self,\n", + " key: PRNGKeyArray,\n", + " *,\n", + " num_inputs: int,\n", + " hidden_size: int,\n", + " depth: int,\n", + " ) -> None:\n", + " num_outputs = 1\n", + "\n", + " # Project input to hidden size\n", + " key, input_proj_key = jax.random.split(key)\n", + " self.input_proj = eqx.nn.Linear(\n", + " in_features=num_inputs,\n", + " out_features=hidden_size,\n", + " key=input_proj_key,\n", + " )\n", + "\n", + " # Create RNN layer\n", + " key, rnn_key = jax.random.split(key)\n", + " rnn_keys = jax.random.split(rnn_key, depth)\n", + " self.rnns = tuple(\n", + " [\n", + " eqx.nn.GRUCell(\n", + " input_size=hidden_size,\n", + " hidden_size=hidden_size,\n", + " key=rnn_key,\n", + " )\n", + " for rnn_key in rnn_keys\n", + " ]\n", + " )\n", + "\n", + " # Project to output\n", + " self.output_proj = eqx.nn.Linear(\n", + " in_features=hidden_size,\n", + " out_features=num_outputs,\n", + " key=key,\n", + " )\n", + "\n", + " self.num_inputs = num_inputs\n", + "\n", + " def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:\n", + " x_n = self.input_proj(obs_n)\n", + " out_carries = []\n", + " for i, rnn in enumerate(self.rnns):\n", + " x_n = rnn(x_n, carry[i])\n", + " out_carries.append(x_n)\n", + " out_n = self.output_proj(x_n)\n", + "\n", + " return out_n, jnp.stack(out_carries, axis=0)\n", + "\n", + "\n", + "class Model(eqx.Module):\n", + " actor: Actor\n", + " critic: Critic\n", + "\n", + " def __init__(\n", + " self,\n", + " key: PRNGKeyArray,\n", + " *,\n", + " num_actor_inputs: int,\n", + " num_actor_outputs: int,\n", + " num_critic_inputs: int,\n", + " min_std: float,\n", + " max_std: float,\n", + " var_scale: float,\n", + " hidden_size: int,\n", + " num_mixtures: int,\n", + " depth: int,\n", + " ) -> None:\n", + " actor_key, critic_key = jax.random.split(key)\n", + " self.actor = Actor(\n", + " actor_key,\n", + " num_inputs=num_actor_inputs,\n", + " num_outputs=num_actor_outputs,\n", + " min_std=min_std,\n", + " max_std=max_std,\n", + " var_scale=var_scale,\n", + " hidden_size=hidden_size,\n", + " num_mixtures=num_mixtures,\n", + " depth=depth,\n", + " )\n", + " self.critic = Critic(\n", + " critic_key,\n", + " hidden_size=hidden_size,\n", + " depth=depth,\n", + " num_inputs=num_critic_inputs,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": { + "id": "10" + }, + "source": [ + "## Config\n", + "\n", + "The [ksim framework](https://github.com/kscalelabs/ksim) is based on [xax](https://github.com/kscalelabs/xax), a JAX training library built by K-Scale. To provide configuration options, xax uses a Config dataclass to parse command-line options. We define the config here." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": { + "id": "11" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class HumanoidWalkingTaskConfig(ksim.PPOConfig):\n", + " \"\"\"Config for the humanoid walking task.\"\"\"\n", + "\n", + " # Model parameters.\n", + " hidden_size: int = xax.field(\n", + " value=128,\n", + " help=\"The hidden size for the MLPs.\",\n", + " )\n", + " depth: int = xax.field(\n", + " value=5,\n", + " help=\"The depth for the MLPs.\",\n", + " )\n", + " num_mixtures: int = xax.field(\n", + " value=5,\n", + " help=\"The number of mixtures for the actor.\",\n", + " )\n", + " var_scale: float = xax.field(\n", + " value=0.5,\n", + " help=\"The scale for the standard deviations of the actor.\",\n", + " )\n", + " use_acc_gyro: bool = xax.field(\n", + " value=True,\n", + " help=\"Whether to use the IMU acceleration and gyroscope observations.\",\n", + " )\n", + "\n", + " # Curriculum parameters.\n", + " num_curriculum_levels: int = xax.field(\n", + " value=100,\n", + " help=\"The number of curriculum levels to use.\",\n", + " )\n", + " increase_threshold: float = xax.field(\n", + " value=5.0,\n", + " help=\"Increase the curriculum level when the mean trajectory length is above this threshold.\",\n", + " )\n", + " decrease_threshold: float = xax.field(\n", + " value=1.0,\n", + " help=\"Decrease the curriculum level when the mean trajectory length is below this threshold.\",\n", + " )\n", + " min_level_steps: int = xax.field(\n", + " value=1,\n", + " help=\"The minimum number of steps to wait before changing the curriculum level.\",\n", + " )\n", + "\n", + " # Optimizer parameters.\n", + " learning_rate: float = xax.field(\n", + " value=3e-4,\n", + " help=\"Learning rate for PPO.\",\n", + " )\n", + " adam_weight_decay: float = xax.field(\n", + " value=1e-5,\n", + " help=\"Weight decay for the Adam optimizer.\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "12", + "metadata": { + "id": "12" + }, + "source": [ + "## Task\n", + "\n", + "The meat-and-potatoes of our training code is the task. This defines the observations, rewards, model calling logic, and everything else needed by `ksim` to train our reinforcement learning agent." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": { + "id": "13" + }, + "outputs": [], + "source": [ + "class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]):\n", + " def get_optimizer(self) -> optax.GradientTransformation:\n", + " return (\n", + " optax.adam(self.config.learning_rate)\n", + " if self.config.adam_weight_decay == 0.0\n", + " else optax.adamw(self.config.learning_rate, weight_decay=self.config.adam_weight_decay)\n", + " )\n", + "\n", + " def get_mujoco_model(self) -> mujoco.MjModel:\n", + " mjcf_path = asyncio.run(ksim.get_mujoco_model_path(\"kbot\", name=\"robot\"))\n", + " return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene=\"smooth\")\n", + "\n", + " def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata:\n", + " metadata = asyncio.run(ksim.get_mujoco_model_metadata(\"kbot\"))\n", + " if metadata.joint_name_to_metadata is None:\n", + " raise ValueError(\"Joint metadata is not available\")\n", + " if metadata.actuator_type_to_metadata is None:\n", + " raise ValueError(\"Actuator metadata is not available\")\n", + " return metadata\n", + "\n", + " def get_actuators(\n", + " self,\n", + " physics_model: ksim.PhysicsModel,\n", + " metadata: ksim.Metadata | None = None,\n", + " ) -> ksim.Actuators:\n", + " assert metadata is not None, \"Metadata is required\"\n", + " return ksim.PositionActuators(\n", + " physics_model=physics_model,\n", + " metadata=metadata,\n", + " )\n", + "\n", + " def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]:\n", + " return [\n", + " ksim.StaticFrictionRandomizer(),\n", + " ksim.ArmatureRandomizer(),\n", + " ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05),\n", + " ksim.JointDampingRandomizer(),\n", + " ksim.JointZeroPositionRandomizer(scale_lower=math.radians(-2), scale_upper=math.radians(2)),\n", + " ]\n", + "\n", + " def get_events(self, physics_model: ksim.PhysicsModel) -> list[ksim.Event]:\n", + " return [\n", + " ksim.PushEvent(\n", + " x_force=1.0,\n", + " y_force=1.0,\n", + " z_force=0.3,\n", + " force_range=(0.5, 1.0),\n", + " x_angular_force=0.0,\n", + " y_angular_force=0.0,\n", + " z_angular_force=0.0,\n", + " interval_range=(0.5, 4.0),\n", + " ),\n", + " ]\n", + "\n", + " def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:\n", + " return [\n", + " ksim.RandomJointPositionReset.create(physics_model, {k: v for k, v in ZEROS}, scale=0.1),\n", + " ksim.RandomJointVelocityReset(),\n", + " ]\n", + "\n", + " def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:\n", + " return [\n", + " ksim.TimestepObservation(),\n", + " ksim.JointPositionObservation(noise=math.radians(2)),\n", + " ksim.JointVelocityObservation(noise=math.radians(10)),\n", + " ksim.ActuatorForceObservation(),\n", + " ksim.CenterOfMassInertiaObservation(),\n", + " ksim.CenterOfMassVelocityObservation(),\n", + " ksim.BasePositionObservation(),\n", + " ksim.BaseOrientationObservation(),\n", + " ksim.BaseLinearVelocityObservation(),\n", + " ksim.BaseAngularVelocityObservation(),\n", + " ksim.BaseLinearAccelerationObservation(),\n", + " ksim.BaseAngularAccelerationObservation(),\n", + " ksim.ActuatorAccelerationObservation(),\n", + " ksim.ProjectedGravityObservation.create(\n", + " physics_model=physics_model,\n", + " framequat_name=\"imu_site_quat\",\n", + " lag_range=(0.0, 0.1),\n", + " noise=math.radians(1),\n", + " ),\n", + " ksim.SensorObservation.create(\n", + " physics_model=physics_model,\n", + " sensor_name=\"imu_acc\",\n", + " noise=1.0,\n", + " ),\n", + " ksim.SensorObservation.create(\n", + " physics_model=physics_model,\n", + " sensor_name=\"imu_gyro\",\n", + " noise=math.radians(10),\n", + " ),\n", + " ]\n", + "\n", + " def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:\n", + " return []\n", + "\n", + " def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:\n", + " return [\n", + " # Standard rewards.\n", + " ksim.NaiveForwardReward(clip_max=1.25, in_robot_frame=False, scale=3.0),\n", + " ksim.NaiveForwardOrientationReward(scale=1.0),\n", + " ksim.StayAliveReward(scale=1.0),\n", + " ksim.UprightReward(scale=0.5),\n", + " # Avoid movement penalties.\n", + " ksim.AngularVelocityPenalty(index=(\"x\", \"y\"), scale=-0.1),\n", + " ksim.LinearVelocityPenalty(index=(\"z\"), scale=-0.1),\n", + " # Normalization penalties.\n", + " ksim.AvoidLimitsPenalty.create(physics_model, scale=-0.01),\n", + " ksim.JointAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n", + " ksim.JointJerkPenalty(scale=-0.01, scale_by_curriculum=True),\n", + " ksim.LinkAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n", + " ksim.LinkJerkPenalty(scale=-0.01, scale_by_curriculum=True),\n", + " ksim.ActionAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n", + " # Bespoke rewards.\n", + " BentArmPenalty.create_penalty(physics_model, scale=-0.1),\n", + " StraightLegPenalty.create_penalty(physics_model, scale=-0.1),\n", + " ]\n", + "\n", + " def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:\n", + " return [\n", + " ksim.BadZTermination(unhealthy_z_lower=0.6, unhealthy_z_upper=1.2),\n", + " ksim.FarFromOriginTermination(max_dist=10.0),\n", + " ]\n", + "\n", + " def get_curriculum(self, physics_model: ksim.PhysicsModel) -> ksim.Curriculum:\n", + " return ksim.DistanceFromOriginCurriculum(\n", + " min_level_steps=5,\n", + " )\n", + "\n", + " def get_model(self, key: PRNGKeyArray) -> Model:\n", + " return Model(\n", + " key,\n", + " num_actor_inputs=51 if self.config.use_acc_gyro else 45,\n", + " num_actor_outputs=len(ZEROS),\n", + " num_critic_inputs=446,\n", + " min_std=0.001,\n", + " max_std=1.0,\n", + " var_scale=self.config.var_scale,\n", + " hidden_size=self.config.hidden_size,\n", + " num_mixtures=self.config.num_mixtures,\n", + " depth=self.config.depth,\n", + " )\n", + "\n", + " def run_actor(\n", + " self,\n", + " model: Actor,\n", + " observations: xax.FrozenDict[str, Array],\n", + " commands: xax.FrozenDict[str, Array],\n", + " carry: Array,\n", + " ) -> tuple[distrax.Distribution, Array]:\n", + " time_1 = observations[\"timestep_observation\"]\n", + " joint_pos_n = observations[\"joint_position_observation\"]\n", + " joint_vel_n = observations[\"joint_velocity_observation\"]\n", + " proj_grav_3 = observations[\"projected_gravity_observation\"]\n", + " imu_acc_3 = observations[\"sensor_observation_imu_acc\"]\n", + " imu_gyro_3 = observations[\"sensor_observation_imu_gyro\"]\n", + "\n", + " obs = [\n", + " jnp.sin(time_1),\n", + " jnp.cos(time_1),\n", + " joint_pos_n, # NUM_JOINTS\n", + " joint_vel_n, # NUM_JOINTS\n", + " proj_grav_3, # 3\n", + " ]\n", + " if self.config.use_acc_gyro:\n", + " obs += [\n", + " imu_acc_3, # 3\n", + " imu_gyro_3, # 3\n", + " ]\n", + "\n", + " obs_n = jnp.concatenate(obs, axis=-1)\n", + " action, carry = model.forward(obs_n, carry)\n", + "\n", + " return action, carry\n", + "\n", + " def run_critic(\n", + " self,\n", + " model: Critic,\n", + " observations: xax.FrozenDict[str, Array],\n", + " commands: xax.FrozenDict[str, Array],\n", + " carry: Array,\n", + " ) -> tuple[Array, Array]:\n", + " time_1 = observations[\"timestep_observation\"]\n", + " dh_joint_pos_j = observations[\"joint_position_observation\"]\n", + " dh_joint_vel_j = observations[\"joint_velocity_observation\"]\n", + " com_inertia_n = observations[\"center_of_mass_inertia_observation\"]\n", + " com_vel_n = observations[\"center_of_mass_velocity_observation\"]\n", + " imu_acc_3 = observations[\"sensor_observation_imu_acc\"]\n", + " imu_gyro_3 = observations[\"sensor_observation_imu_gyro\"]\n", + " proj_grav_3 = observations[\"projected_gravity_observation\"]\n", + " act_frc_obs_n = observations[\"actuator_force_observation\"]\n", + " base_pos_3 = observations[\"base_position_observation\"]\n", + " base_quat_4 = observations[\"base_orientation_observation\"]\n", + "\n", + " obs_n = jnp.concatenate(\n", + " [\n", + " jnp.sin(time_1),\n", + " jnp.cos(time_1),\n", + " dh_joint_pos_j, # NUM_JOINTS\n", + " dh_joint_vel_j / 10.0, # NUM_JOINTS\n", + " com_inertia_n, # 160\n", + " com_vel_n, # 96\n", + " imu_acc_3, # 3\n", + " imu_gyro_3, # 3\n", + " proj_grav_3, # 3\n", + " act_frc_obs_n / 100.0, # NUM_JOINTS\n", + " base_pos_3, # 3\n", + " base_quat_4, # 4\n", + " ],\n", + " axis=-1,\n", + " )\n", + "\n", + " return model.forward(obs_n, carry)\n", + "\n", + " def _model_scan_fn(\n", + " self,\n", + " actor_critic_carry: tuple[Array, Array],\n", + " xs: tuple[ksim.Trajectory, PRNGKeyArray],\n", + " model: Model,\n", + " ) -> tuple[tuple[Array, Array], ksim.PPOVariables]:\n", + " transition, rng = xs\n", + "\n", + " actor_carry, critic_carry = actor_critic_carry\n", + " actor_dist, next_actor_carry = self.run_actor(\n", + " model=model.actor,\n", + " observations=transition.obs,\n", + " commands=transition.command,\n", + " carry=actor_carry,\n", + " )\n", + "\n", + " # Gets the log probabilities of the action.\n", + " log_probs = actor_dist.log_prob(transition.action)\n", + " assert isinstance(log_probs, Array)\n", + "\n", + " value, next_critic_carry = self.run_critic(\n", + " model=model.critic,\n", + " observations=transition.obs,\n", + " commands=transition.command,\n", + " carry=critic_carry,\n", + " )\n", + "\n", + " transition_ppo_variables = ksim.PPOVariables(\n", + " log_probs=log_probs,\n", + " values=value.squeeze(-1),\n", + " )\n", + "\n", + " next_carry = jax.tree.map(\n", + " lambda x, y: jnp.where(transition.done, x, y),\n", + " self.get_initial_model_carry(rng),\n", + " (next_actor_carry, next_critic_carry),\n", + " )\n", + "\n", + " return next_carry, transition_ppo_variables\n", + "\n", + " def get_ppo_variables(\n", + " self,\n", + " model: Model,\n", + " trajectory: ksim.Trajectory,\n", + " model_carry: tuple[Array, Array],\n", + " rng: PRNGKeyArray,\n", + " ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]:\n", + " scan_fn = functools.partial(self._model_scan_fn, model=model)\n", + " next_model_carry, ppo_variables = xax.scan(\n", + " scan_fn,\n", + " model_carry,\n", + " (trajectory, jax.random.split(rng, len(trajectory.done))),\n", + " jit_level=4,\n", + " )\n", + " return ppo_variables, next_model_carry\n", + "\n", + " def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:\n", + " return (\n", + " jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),\n", + " jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),\n", + " )\n", + "\n", + " def sample_action(\n", + " self,\n", + " model: Model,\n", + " model_carry: tuple[Array, Array],\n", + " physics_model: ksim.PhysicsModel,\n", + " physics_state: ksim.PhysicsState,\n", + " observations: xax.FrozenDict[str, Array],\n", + " commands: xax.FrozenDict[str, Array],\n", + " rng: PRNGKeyArray,\n", + " argmax: bool,\n", + " ) -> ksim.Action:\n", + " actor_carry_in, critic_carry_in = model_carry\n", + " action_dist_j, actor_carry = self.run_actor(\n", + " model=model.actor,\n", + " observations=observations,\n", + " commands=commands,\n", + " carry=actor_carry_in,\n", + " )\n", + " action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng)\n", + " return ksim.Action(action=action_j, carry=(actor_carry, critic_carry_in))" + ] + }, + { + "cell_type": "markdown", + "id": "4dcf85a7", + "metadata": { + "id": "4dcf85a7" + }, + "source": [ + "# Launch TensorBoard\n", + "\n", + "The below cell launches TensorBoard to visualize the training progress.\n", + "\n", + "After launching an experiment, please wait for ~5 minutes for the task to start running and then click the reload button in the top right corner of the TensorBoard page. You can also open the settings and check the \"Reload data\" option to automatically reload the TensorBoard." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": { + "id": "14" + }, + "outputs": [], + "source": [ + "# Launch TensorBoard\n", + "%load_ext tensorboard\n", + "%tensorboard --logdir humanoid_walking_task" + ] + }, + { + "cell_type": "markdown", + "id": "15", + "metadata": { + "id": "15" + }, + "source": [ + "## Launching an Experiment\n", + "\n", + "To launch an experiment with `xax`, you can use `Task.launch(config)`. Note that this is usually intended to be called from the command-line, so it will by default attempt to parse additional command-line arguments unless `use_cli=False` is set.\n", + "\n", + "By default, runs will be logged to a directory called `run_[x]` in the task directory (/content/humanoid_walking_task/ in Colab). From there, you can download the ckpt `.bin` files and the TensorBoard logs.\n", + "\n", + "Also note that since this is a Jupyter notebook, the task will be unable to find the task training code and emit a warning like \"Could not resolve task path for , returning current working directory\". You can safely ignore this warning." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": { + "id": "16" + }, + "outputs": [], + "source": [ + "if __name__ == \"__main__\":\n", + " HumanoidWalkingTask.launch(\n", + " HumanoidWalkingTaskConfig(\n", + " # Training parameters.\n", + " num_envs=2048,\n", + " batch_size=256,\n", + " num_passes=4,\n", + " epochs_per_log_step=1,\n", + " rollout_length_seconds=8.0,\n", + " global_grad_clip=2.0,\n", + " # Simulation parameters.\n", + " dt=0.002,\n", + " ctrl_dt=0.02,\n", + " iterations=8,\n", + " ls_iterations=8,\n", + " action_latency_range=(0.003, 0.01), # Simulate 3-10ms of latency.\n", + " drop_action_prob=0.05, # Drop 5% of commands.\n", + " # Visualization parameters\n", + " render_track_body_id=0,\n", + " # Checkpointing parameters.\n", + " save_every_n_seconds=60,\n", + " ),\n", + " use_cli=False,\n", + " )" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file