{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The game of Tic-Tac-Toe.\n",
"\n",
"---\n",
"\n",
"## Original project: [GitHub](https://github.com/nczempin/gym-tic-tac-toe).\n",
"\n",
"### Project description:\n",
"\n",
"The project implements a tic-tac-toe environment built on the Gymnasium library. The environment models the game logic and enforces the rules: legality of moves, detection of the winner, and alternation of turns between the players. Its core methods are step(action) to make a move, reset() to clear the board, and render() to display the current state of the game. A generator of the available moves is also provided.\n",
"\n",
"For testing and demonstration, moves are chosen at random with the random_move() function. Play is simulated over a series of episodes (300 games) in which the players take turns until the game ends. After each episode the rewards are computed and the results are printed.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Environment implementation:\n",
"\n",
"**Environment** – the setting in which an agent performs its actions in order to solve a task."
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"import gymnasium as gym\n",
"from gymnasium import spaces\n",
"\n",
"\n",
"class TicTacToeEnv(gym.Env):\n",
"    metadata = {'render.modes': ['human']}\n",
"\n",
"    symbols = ['O', ' ', 'X']\n",
"\n",
"    def __init__(self) -> None:\n",
"        super().__init__()\n",
"        self.action_space = spaces.Discrete(9)\n",
"        self.observation_space = spaces.Discrete(9*3*2) # flattened\n",
"        self.reset()\n",
"\n",
"    def step(self, action):\n",
"        done = False\n",
"        reward = 0\n",
"\n",
"        p, square = action\n",
"\n",
"        # check move legality\n",
"        board = self.state['board']\n",
"        proposed = board[square]\n",
"        om = self.state['on_move']\n",
"\n",
"        if proposed != 0:  # square already occupied\n",
"            print(\"illegal move \", action, \". (square occupied): \", square)\n",
"            done = True\n",
"            reward = -1 * om  # player who did NOT make the illegal move\n",
"        elif p != om:  # wrong player on move\n",
"            print(\"illegal move \", action, \" not on move: \", p)\n",
"            done = True\n",
"            reward = -1 * om  # player who did NOT make the illegal move\n",
"        else:\n",
"            board[square] = p\n",
"            self.state['on_move'] = -p\n",
"\n",
"        # check game over\n",
"        for i in range(3):\n",
"            # horizontals and verticals\n",
"            if ((board[i * 3] == p and board[i * 3 + 1] == p and board[i * 3 + 2] == p)\n",
"                    or (board[i + 0] == p and board[i + 3] == p and board[i + 6] == p)):\n",
"                reward = p\n",
"                done = True\n",
"                break\n",
"        # diagonals\n",
"        if ((board[0] == p and board[4] == p and board[8] == p)\n",
"                or (board[2] == p and board[4] == p and board[6] == p)):\n",
"            reward = p\n",
"            done = True\n",
"\n",
"        return self.state, reward, done, {}\n",
"\n",
"    def reset(self):\n",
"        self.state = {}\n",
"        self.state['board'] = [0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
"        self.state['on_move'] = 1\n",
"        return self.state\n",
"\n",
"    def render(self, close=False):\n",
"        if close:\n",
"            return\n",
"        print(\"on move: \", self.symbols[self.state['on_move']+1])\n",
"        for i in range(9):\n",
"            print(self.symbols[self.state['board'][i]+1], end=\" \")\n",
"            if i % 3 == 2:\n",
"                print()\n",
"        print()\n",
"\n",
"    def move_generator(self):\n",
"        moves = []\n",
"        for i in range(9):\n",
"            if self.state['board'][i] == 0:\n",
"                p = self.state['on_move']\n",
"                m = [p, i]\n",
"                moves.append(m)\n",
"        return moves"
]
},
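{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch of the environment defined above (added here for illustration, not part of the original project): reset the board, ask the move generator for the legal moves, play one random move and render the result."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: one random move in TicTacToeEnv (demo_env is an ad hoc name, not from the original project)\n",
"import random\n",
"\n",
"demo_env = TicTacToeEnv()\n",
"state = demo_env.reset()            # board of 9 zeros, player 1 ('X') on move\n",
"moves = demo_env.move_generator()   # list of [player, square] pairs for the empty squares\n",
"state, reward, done, info = demo_env.step(random.choice(moves))\n",
"demo_env.render()                   # prints whose turn it is and the 3x3 board"
]
},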
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementation of the main game loop:"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"on move: O\n",
" \n",
" \n",
"X \n",
"\n",
"on move: X\n",
" \n",
"O \n",
"X \n",
"\n",
"on move: O\n",
" X \n",
"O \n",
"X \n",
"\n",
"on move: X\n",
"O X \n",
"O \n",
"X \n",
"\n",
"on move: O\n",
"O X \n",
"O \n",
"X X \n",
"\n",
"on move: X\n",
"O X O \n",
"O \n",
"X X \n",
"\n",
"on move: O\n",
"O X O \n",
"O X \n",
"X X \n",
"\n",
"on move: X\n",
"O X O \n",
"O X O \n",
"X X \n",
"\n",
"on move: O\n",
"O X O \n",
"O X O \n",
"X X X \n",
"\n",
"Episode 50, Total Reward: 1\n",
"Average Reward: 0.34\n",
"\n",
"on move: O\n",
" X \n",
" \n",
" \n",
"\n",
"on move: X\n",
" X \n",
" \n",
"O \n",
"\n",
"on move: O\n",
" X \n",
"X \n",
"O \n",
"\n",
"on move: X\n",
"O X \n",
"X \n",
"O \n",
"\n",
"on move: O\n",
"O X X \n",
"X \n",
"O \n",
"\n",
"on move: X\n",
"O X X \n",
"X O \n",
"O \n",
"\n",
"on move: O\n",
"O X X \n",
"X X O \n",
"O \n",
"\n",
"on move: X\n",
"O X X \n",
"X X O \n",
"O O \n",
"\n",
"on move: O\n",
"O X X \n",
"X X O \n",
"O X O \n",
"\n",
"Episode 100, Total Reward: 1\n",
"Average Reward: 0.36\n",
"\n",
"on move: O\n",
"X \n",
" \n",
" \n",
"\n",
"on move: X\n",
"X \n",
" \n",
"O \n",
"\n",
"on move: O\n",
"X \n",
"X \n",
"O \n",
"\n",
"on move: X\n",
"X \n",
"X \n",
"O O \n",
"\n",
"on move: O\n",
"X X \n",
"X \n",
"O O \n",
"\n",
"on move: X\n",
"X X \n",
"X \n",
"O O O \n",
"\n",
"Episode 150, Total Reward: -1\n",
"Average Reward: 0.35333333333333333\n",
"\n",
"on move: O\n",
" \n",
" \n",
" X \n",
"\n",
"on move: X\n",
" \n",
" O \n",
" X \n",
"\n",
"on move: O\n",
" X \n",
" O \n",
" X \n",
"\n",
"on move: X\n",
"O X \n",
" O \n",
" X \n",
"\n",
"on move: O\n",
"O X X \n",
" O \n",
" X \n",
"\n",
"on move: X\n",
"O X X \n",
"O O \n",
" X \n",
"\n",
"on move: O\n",
"O X X \n",
"O O \n",
"X X \n",
"\n",
"on move: X\n",
"O X X \n",
"O O \n",
"X X O \n",
"\n",
"on move: O\n",
"O X X \n",
"O X O \n",
"X X O \n",
"\n",
"Episode 200, Total Reward: 1\n",
"Average Reward: 0.355\n",
"\n",
"on move: O\n",
" X \n",
" \n",
" \n",
"\n",
"on move: X\n",
" X \n",
" \n",
"O \n",
"\n",
"on move: O\n",
" X \n",
" \n",
"O X \n",
"\n",
"on move: X\n",
"O X \n",
" \n",
"O X \n",
"\n",
"on move: O\n",
"O X \n",
" \n",
"O X X \n",
"\n",
"on move: X\n",
"O X O \n",
" \n",
"O X X \n",
"\n",
"on move: O\n",
"O X O \n",
"X \n",
"O X X \n",
"\n",
"on move: X\n",
"O X O \n",
"X O \n",
"O X X \n",
"\n",
"Episode 250, Total Reward: -1\n",
"Average Reward: 0.384\n",
"\n",
"on move: O\n",
"X \n",
" \n",
" \n",
"\n",
"on move: X\n",
"X \n",
" O \n",
" \n",
"\n",
"on move: O\n",
"X \n",
" O \n",
" X \n",
"\n",
"on move: X\n",
"X \n",
" O O \n",
" X \n",
"\n",
"on move: O\n",
"X \n",
"X O O \n",
" X \n",
"\n",
"on move: X\n",
"X O \n",
"X O O \n",
" X \n",
"\n",
"on move: O\n",
"X O \n",
"X O O \n",
" X X \n",
"\n",
"on move: X\n",
"X O O \n",
"X O O \n",
" X X \n",
"\n",
"on move: O\n",
"X O O \n",
"X O O \n",
"X X X \n",
"\n",
"Episode 300, Total Reward: 1\n",
"Average Reward: 0.36333333333333334\n",
"\n"
]
}
],
"source": [
"import random\n",
"\n",
"\n",
"def random_move(moves):\n",
"    m = random.choice(moves)\n",
"    return m\n",
"\n",
"\n",
"env = TicTacToeEnv()\n",
"\n",
"alpha = 0.01\n",
"beta = 0.01\n",
"\n",
"num_episodes = 300\n",
"\n",
"collected_rewards = []\n",
"oom = 1\n",
"\n",
"for i in range(num_episodes):\n",
"    state = env.reset()\n",
"\n",
"    total_reward = 0\n",
"\n",
"    done = False\n",
"    om = oom\n",
"\n",
"    for j in range(9):\n",
"        moves = env.move_generator()\n",
"        if not moves:\n",
"            break\n",
"\n",
"        if len(moves) == 1:\n",
"            # only a single possible move\n",
"            move = moves[0]\n",
"        else:\n",
"            move = random_move(moves)\n",
"\n",
"        next_state, reward, done, info = env.step(move)\n",
"        total_reward += reward\n",
"        state = next_state\n",
"\n",
"        if (i + 1) % 50 == 0:\n",
"            env.render()\n",
"\n",
"        if done:\n",
"            break\n",
"\n",
"        om = -om\n",
"\n",
"    collected_rewards.append(total_reward)\n",
"\n",
"    if (i + 1) % 50 == 0:\n",
"        print(f\"Episode {i+1}, Total Reward: {total_reward}\")\n",
"        average_reward = sum(collected_rewards) / len(collected_rewards)\n",
"        print(f\"Average Reward: {average_reward}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Updated project implementation.\n",
"\n",
"### Project description:\n",
"\n",
"The updated version keeps the tic-tac-toe environment but rewrites it against the current Gymnasium API and adds a tabular Q-learning agent. The environment stores the board as a NumPy array, signals wins, draws and illegal moves through the reward, and the agent is trained over a series of episodes; the results are summarised with cumulative win/draw statistics and a plot.\n",
"\n",
"### Main changes:\n",
"\n",
"- The board is a NumPy array and the observation space is a spaces.Box over the 9 cells.\n",
"- reset() accepts a seed and step() returns the five-element Gymnasium tuple (obs, reward, terminated, truncated, info).\n",
"- Draws are detected explicitly, and an illegal move is penalised without ending the episode.\n",
"- A Q-learning agent with an epsilon-greedy policy replaces purely random move selection.\n",
"- Per-episode statistics are collected and plotted with matplotlib.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Environment implementation:\n",
"\n",
"**Environment** – the setting in which an agent performs its actions in order to solve a task."
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"class TicTacToeEnv(gym.Env):\n",
"    metadata: dict[str, list[str]] = {\"render_modes\": [\"ansi\"]}\n",
"\n",
"    def __init__(self) -> None:\n",
"        self.action_space = spaces.Discrete(9)  # 9 cells\n",
"        self.observation_space = spaces.Box(low=-1, high=1, shape=(9,), dtype=int)\n",
"        self.symbols: dict[int, str] = {1: \"X\", -1: \"O\", 0: \" \"}\n",
"\n",
"        self.reset()\n",
"\n",
"    def reset(self, seed=None):\n",
"        super().reset(seed=seed)\n",
"        self.board = np.zeros(9, dtype=int)  # Empty board\n",
"        self.current_player = 1  # Player 1 (X) moves first\n",
"        return self.board\n",
"\n",
"    def step(self, action):\n",
"        if self.board[action] != 0:\n",
"            # Illegal move (the cell is already occupied)\n",
"            reward = -self.current_player  # Penalty: the reward goes to the player who did NOT make the mistake\n",
"            self.current_player *= -1  # Pass the turn to the other player\n",
"            return self.board, reward, False, False, {}\n",
"\n",
"        # Make the move\n",
"        self.board[action] = self.current_player\n",
"\n",
"        # Check for a win\n",
"        if self.check_winner(self.current_player):\n",
"            reward = self.current_player\n",
"            terminated = True\n",
"        elif np.all(self.board != 0):\n",
"            # Draw\n",
"            reward = 0\n",
"            terminated = True\n",
"        else:\n",
"            # The game continues\n",
"            reward = 0\n",
"            terminated = False\n",
"            self.current_player *= -1  # Switch turns\n",
"\n",
"        return self.board, reward, terminated, False, {}\n",
"\n",
"    def check_winner(self, player):\n",
"        winning_positions: list[tuple[int, int, int]] = [\n",
"            (0, 1, 2), (3, 4, 5), (6, 7, 8),  # Rows\n",
"            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # Columns\n",
"            (0, 4, 8), (2, 4, 6),  # Diagonals\n",
"        ]\n",
"        for positions in winning_positions:\n",
"            if all(self.board[pos] == player for pos in positions):\n",
"                return True\n",
"        return False\n",
"\n",
"    def render(self):\n",
"        board = self.board.reshape(3, 3)\n",
"        print(\"\\n\".join(\" | \".join(self.symbols[cell] for cell in row) for row in board))\n",
"        print()"
]
},
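{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the updated environment API (an illustrative sketch added here, not part of the original notebook): step() now returns the five-element Gymnasium tuple, and repeating an occupied square is penalised in favour of the other player without terminating the episode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch of the updated API (env_check is an ad hoc name, not from the original notebook)\n",
"env_check = TicTacToeEnv()\n",
"obs = env_check.reset()                                        # NumPy array of 9 zeros, X (+1) to move\n",
"obs, reward, terminated, truncated, info = env_check.step(4)   # X takes the centre\n",
"env_check.render()\n",
"obs, reward, terminated, truncated, info = env_check.step(4)   # O tries the same square: illegal\n",
"print(reward, terminated)                                      # 1 False -> the reward is credited to X, the game goes on"
]
},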
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Agent implementation:\n",
"\n",
"**Agent** – the learner that performs actions in an environment to receive rewards, making decisions based on its goals and the information it observes."
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"class TicTacToeAgent:\n",
"    def __init__(self, learning_rate=0.1, discount_factor=0.9, epsilon=0.1):\n",
"        self.q_table = {}  # Q-table keyed by (state, action) pairs\n",
"        self.learning_rate = learning_rate\n",
"        self.discount_factor = discount_factor\n",
"        self.epsilon = epsilon\n",
"\n",
"    def get_state_key(self, state):\n",
"        return tuple(state)\n",
"\n",
"    def select_action(self, state, possible_actions):\n",
"        state_key = self.get_state_key(state)\n",
"        if np.random.rand() < self.epsilon:\n",
"            # Exploration: random move\n",
"            return np.random.choice(possible_actions)\n",
"        # Exploitation: pick the action with the highest Q-value (0 for unseen pairs)\n",
"        return max(possible_actions, key=lambda a: self.q_table.get((state_key, a), 0))\n",
"\n",
"    def update(self, state, action, reward, next_state, possible_actions, terminated):\n",
"        state_key = self.get_state_key(state)\n",
"        next_state_key = self.get_state_key(next_state)\n",
"\n",
"        if (state_key, action) not in self.q_table:\n",
"            self.q_table[(state_key, action)] = 0\n",
"\n",
"        if terminated:\n",
"            future_reward = 0\n",
"        else:\n",
"            future_reward = max(self.q_table.get((next_state_key, a), 0) for a in possible_actions)\n",
"\n",
"        # Q-learning update\n",
"        td_target = reward + self.discount_factor * future_reward\n",
"        td_error = td_target - self.q_table[(state_key, action)]\n",
"        self.q_table[(state_key, action)] += self.learning_rate * td_error"
]
},
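{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the update applied in TicTacToeAgent.update is the standard tabular Q-learning rule (the formula below is added here for clarity):\n",
"\n",
"$$Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left[ r + \\gamma \\max_{a'} Q(s', a') - Q(s, a) \\right]$$\n",
"\n",
"where $\\alpha$ is learning_rate, $\\gamma$ is discount_factor, $r$ is the reward, and the maximum over $a'$ runs over the actions available after the move (taken as 0 once the episode has terminated). For example, with $\\alpha = 0.1$, $\\gamma = 0.9$, $Q(s, a) = 0$, $r = 1$ and a terminal next state, the entry becomes $0 + 0.1 \\cdot (1 + 0 - 0) = 0.1$."
]
},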
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Implementation of the main training loop:"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average Reward (last 50 episodes): 0.26\n",
"Average Reward (last 50 episodes): 0.26\n",
"Average Reward (last 50 episodes): 0.42\n",
"Average Reward (last 50 episodes): 0.38\n",
"Average Reward (last 50 episodes): 0.38\n",
"Average Reward (last 50 episodes): 0.26\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# Training parameters\n",
"learning_rate = 0.1\n",
"discount_factor = 0.9\n",
"epsilon = 0.1\n",
"episodes = 300\n",
"\n",
"# Initialise the environment and the agent\n",
"env = TicTacToeEnv()\n",
"agent = TicTacToeAgent(learning_rate, discount_factor, epsilon)\n",
"\n",
"# Initialise the metrics\n",
"statistics = {\n",
"    \"Episode\": [],\n",
"    \"Total Reward\": [],\n",
"    \"Wins_X\": [],\n",
"    \"Wins_O\": [],\n",
"    \"Draws\": [],\n",
"}\n",
"\n",
"for episode in range(episodes):\n",
"    state = env.reset()\n",
"    done = False\n",
"    total_reward = 0\n",
"    wins_X = 0\n",
"    wins_O = 0\n",
"    draws = 0\n",
"\n",
"    while not done:\n",
"        possible_actions = [i for i in range(9) if state[i] == 0]\n",
"        action = agent.select_action(state, possible_actions)\n",
"        next_state, reward, terminated, truncated, _ = env.step(action)\n",
"\n",
"        # Update the Q-table\n",
"        agent.update(state, action, reward, next_state, possible_actions, terminated)\n",
"\n",
"        # Statistics\n",
"        total_reward += reward\n",
"        if reward == 1:\n",
"            wins_X += 1\n",
"        elif reward == -1:\n",
"            wins_O += 1\n",
"        elif reward == 0 and terminated:  # Draw\n",
"            draws += 1\n",
"\n",
"        state = next_state\n",
"        done = terminated\n",
"\n",
"    # Store the per-episode data\n",
"    statistics[\"Episode\"].append(episode + 1)\n",
"    statistics[\"Total Reward\"].append(total_reward)\n",
"    statistics[\"Wins_X\"].append(wins_X)\n",
"    statistics[\"Wins_O\"].append(wins_O)\n",
"    statistics[\"Draws\"].append(draws)\n",
"\n",
"    # Report progress every 50 episodes\n",
"    if (episode + 1) % 50 == 0:\n",
"        average_reward = sum(statistics[\"Total Reward\"][-50:]) / 50\n",
"        print(f\"Average Reward (last 50 episodes): {average_reward}\")\n",
"\n",
"\n",
"# Plot the cumulative results\n",
"plt.figure(figsize=(12, 6))\n",
"\n",
"episodes_range = statistics[\"Episode\"]\n",
"wins_X = np.cumsum(statistics[\"Wins_X\"])\n",
"wins_O = np.cumsum(statistics[\"Wins_O\"])\n",
"draws = np.cumsum(statistics[\"Draws\"])\n",
"\n",
"plt.plot(episodes_range, wins_X, label=\"Wins X\", color=\"red\")\n",
"plt.plot(episodes_range, wins_O, label=\"Wins O\", color=\"blue\")\n",
"plt.plot(episodes_range, draws, label=\"Draws\", color=\"green\")\n",
"\n",
"plt.xlabel(\"Episodes\")\n",
"plt.ylabel(\"Cumulative Count\")\n",
"plt.title(\"Cumulative Performance Over Episodes\")\n",
"plt.legend()\n",
"plt.grid()\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
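{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short evaluation sketch (added here for illustration, not part of the original notebook): the trained agent plays greedily (epsilon set to 0) as X against a random opponent playing O, and the outcomes of 100 games are counted."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative evaluation of the trained agent (results and n_games are ad hoc names, not from the original notebook)\n",
"agent.epsilon = 0  # greedy play: always pick the best known Q-value\n",
"n_games = 100\n",
"results = {\"X\": 0, \"O\": 0, \"draw\": 0}\n",
"\n",
"for _ in range(n_games):\n",
"    state = env.reset()\n",
"    terminated = False\n",
"    while not terminated:\n",
"        possible_actions = [i for i in range(9) if state[i] == 0]\n",
"        if env.current_player == 1:\n",
"            action = agent.select_action(state, possible_actions)  # greedy X\n",
"        else:\n",
"            action = np.random.choice(possible_actions)            # random O\n",
"        state, reward, terminated, truncated, _ = env.step(action)\n",
"    if reward == 1:\n",
"        results[\"X\"] += 1\n",
"    elif reward == -1:\n",
"        results[\"O\"] += 1\n",
"    else:\n",
"        results[\"draw\"] += 1\n",
"\n",
"print(results)"
]
}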
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}