{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"Taxi-v3.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyOyJLxJk4jmhiCnhYAcoyNF"},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"markdown","metadata":{"id":"yE4W6pk8dGxB"},"source":["https://www.learndatasci.com/tutorials/reinforcement-q-learning-scratch-python-openai-gym/"]},{"cell_type":"code","metadata":{"id":"j_Bz8FJuYnYO"},"source":["!pip install cmake 'gym[atari]' scipy"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"32VTNLFPY0CY"},"source":["import gym\n","\n","env = gym.make(\"Taxi-v3\").env\n","\n","env.render()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mDyF4qnTY6nF"},"source":["env.reset() # reset environment to a new, random state\n","env.render()\n","\n","print(\"Action Space {}\".format(env.action_space))\n","print(\"State Space {}\".format(env.observation_space))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"pHJMQHgZZHNx"},"source":["state = env.encode(3, 1, 2, 0) # (taxi row, taxi column, passenger index, destination index)\n","print(\"State:\", state)\n","\n","env.s = state\n","env.render()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NsS9-CwMZN8a"},"source":["env.P[328]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"bCSozaDQZR2O"},"source":["env.s = 328 # set environment to illustration's state\n","\n","epochs = 0\n","penalties, reward = 0, 0\n","\n","frames = [] # for animation\n","\n","done = False\n","\n","while not done:\n"," action = env.action_space.sample()\n"," state, reward, done, info = env.step(action)\n","\n"," if reward == -10:\n"," penalties += 1\n"," \n"," # Put each rendered frame into dict for animation\n"," frames.append({\n"," 'frame': env.render(mode='ansi'),\n"," 'state': state,\n"," 'action': action,\n"," 'reward': reward\n"," }\n"," )\n","\n"," epochs += 1\n"," \n"," \n","print(\"Timesteps taken: {}\".format(epochs))\n","print(\"Penalties incurred: {}\".format(penalties))"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vgzJIvpiZjjk"},"source":["from IPython.display import clear_output\n","from time import sleep\n","\n","def print_frames(frames):\n"," for i, frame in enumerate(frames):\n"," clear_output(wait=True)\n"," print(frame['frame'])\n"," print(f\"Timestep: {i + 1}\")\n"," print(f\"State: {frame['state']}\")\n"," print(f\"Action: {frame['action']}\")\n"," print(f\"Reward: {frame['reward']}\")\n"," sleep(.1)\n"," \n","print_frames(frames)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"aT8dWtEDZu1Q"},"source":["import numpy as np\n","q_table = np.zeros([env.observation_space.n, env.action_space.n])"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"1v3A0Yh2bV1A"},"source":["%%time\n","\"\"\"Training the agent\"\"\"\n","\n","import random\n","from IPython.display import clear_output\n","\n","# Hyperparameters\n","alpha = 0.1\n","gamma = 0.6\n","epsilon = 0.1\n","\n","# For plotting metrics\n","all_epochs = []\n","all_penalties = []\n","\n","for i in range(1, 100001):\n"," state = env.reset()\n","\n"," epochs, penalties, reward, = 0, 0, 0\n"," done = False\n"," \n"," while not done:\n"," \n"," if random.uniform(0, 1) < epsilon:\n"," action = env.action_space.sample() # Explore action space, non-greedy (NG) action selection \n"," else:\n"," action = np.argmax(q_table[state]) # Exploit learned 
{"cell_type":"code","metadata":{"id":"1v3A0Yh2bV1A"},"source":["%%time\n","\"\"\"Training the agent\"\"\"\n","\n","import random\n","from IPython.display import clear_output\n","\n","# Hyperparameters\n","alpha = 0.1\n","gamma = 0.6\n","epsilon = 0.1\n","\n","# For plotting metrics\n","all_epochs = []\n","all_penalties = []\n","\n","for i in range(1, 100001):\n","    state = env.reset()\n","\n","    epochs, penalties, reward = 0, 0, 0\n","    done = False\n","\n","    while not done:\n","        if random.uniform(0, 1) < epsilon:\n","            action = env.action_space.sample() # Explore: non-greedy (NG) action selection\n","        else:\n","            action = np.argmax(q_table[state]) # Exploit learned values: greedy (G) action selection\n","\n","        next_state, reward, done, info = env.step(action)\n","\n","        old_value = q_table[state, action]\n","        next_max = np.max(q_table[next_state]) # The update always bootstraps from the greedy value (np.max); Q-learning is off-policy because the action actually taken may come from a different (non-greedy, random) policy\n","\n","        new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max)\n","        q_table[state, action] = new_value\n","\n","        if reward == -10:\n","            penalties += 1\n","\n","        state = next_state\n","        epochs += 1\n","\n","    if i % 100 == 0:\n","        clear_output(wait=True)\n","        print(f\"Episode: {i}\")\n","\n","print(\"Training finished.\\n\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kVUMYQEubcqW"},"source":["q_table[328]"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"p92pLuSeb2n0"},"source":["\"\"\"Evaluate agent's performance after Q-learning\"\"\"\n","frames = [] # for animation\n","\n","total_epochs, total_penalties = 0, 0\n","episodes = 100\n","\n","for _ in range(episodes):\n","    state = env.reset()\n","    epochs, penalties, reward = 0, 0, 0\n","\n","    done = False\n","\n","    while not done:\n","        action = np.argmax(q_table[state])\n","        state, reward, done, info = env.step(action)\n","\n","        # Put each rendered frame into dict for animation\n","        frames.append({\n","            'frame': env.render(mode='ansi'),\n","            'state': state,\n","            'action': action,\n","            'reward': reward\n","        })\n","\n","        if reward == -10:\n","            penalties += 1\n","\n","        epochs += 1\n","\n","    total_penalties += penalties\n","    total_epochs += epochs\n","\n","print(f\"Results after {episodes} episodes:\")\n","print(f\"Average timesteps per episode: {total_epochs / episodes}\")\n","print(f\"Average penalties per episode: {total_penalties / episodes}\")"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"e5YWHRyIb9Ot"},"source":["print_frames(frames)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"p3_PVA-XolVY"},"source":[""],"execution_count":null,"outputs":[]}]}