56 lines
1.4 KiB
Python
56 lines
1.4 KiB
Python
from tron.player import Player, Direction
|
|
from tron.game import Tile
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
import os
|
|
|
|
class Net(nn.Module):
    """Single-layer perceptron mapping a game board to one score per action.

    Each sample in the batch is flattened, so any tensor of shape (N, ...)
    whose trailing dimensions hold ``in_features`` values in total is
    accepted.

    The defaults are 144 inputs (presumably a 12x12 board -- confirm against
    ``state_for_player``) and 4 outputs. ``Ai.action`` maps the output index
    to UP, RIGHT, DOWN, LEFT.

    :param in_features: number of values per sample after flattening
    :param n_actions: number of output scores produced per sample
    """

    def __init__(self, in_features: int = 144, n_actions: int = 4):
        super().__init__()
        # The attribute name stays ``fc1`` so previously saved state_dicts
        # (e.g. the ai.bak file loaded by Ai) still match their keys.
        self.fc1 = nn.Linear(in_features, n_actions)

    def forward(self, x):
        """Return raw (unnormalised) action scores of shape (N, n_actions).

        :param x: batch tensor whose first dimension is the batch size N
        """
        # torch.flatten, unlike Tensor.view, also handles non-contiguous
        # inputs (it copies only when a view is impossible).
        return self.fc1(torch.flatten(x, start_dim=1))
|
|
|
|
|
|
class Ai(Player):
    """
    This class implements an AI based on the perceptron defined in class Net.

    On construction, network weights are restored from ``ai.bak`` when that
    file exists. Otherwise the freshly (randomly) initialised weights of
    ``Net`` are used.
    """

    # Network output index i selects _DIRECTIONS[i]. This is the original
    # 1=UP, 2=RIGHT, 3=DOWN, 4=LEFT numbering, shifted to 0-based.
    _DIRECTIONS = (Direction.UP, Direction.RIGHT, Direction.DOWN, Direction.LEFT)

    def __init__(self):
        super().__init__()
        self.net = Net()

        # Load network weights if they have been saved already.
        # NOTE(review): find_file is inherited (presumably from Player) and is
        # assumed to return a filesystem path -- confirm.
        weights_path = self.find_file('ai.bak')
        if os.path.isfile(weights_path):
            # map_location="cpu" lets weights saved from a GPU session load
            # on a CPU-only machine.
            # SECURITY: torch.load unpickles the file; only load trusted files.
            self.net.load_state_dict(torch.load(weights_path, map_location="cpu"))

    def action(self, map, id):
        """Choose the next direction for player ``id``.

        The parameter names shadow the ``map``/``id`` builtins. They are kept
        unchanged because callers may pass them by keyword.

        :param map: game object exposing ``state_for_player(id)``; the result
            is presumably a 2-D numpy array (it is reshaped to
            (1, 1, rows, cols) below)
        :param id: identifier of the player to choose a move for
        :return: one of Direction.UP, Direction.RIGHT, Direction.DOWN,
            Direction.LEFT -- the direction with the highest network score
        """
        game_map = map.state_for_player(id)

        # Add batch and channel axes: (rows, cols) -> (1, 1, rows, cols).
        board = np.reshape(game_map, (1, 1, game_map.shape[0], game_map.shape[1]))
        board = torch.from_numpy(board).float()

        # Pure inference: skip building an autograd graph on every move.
        with torch.no_grad():
            scores = self.net(board)

        # Index of the highest score for the single sample in the batch.
        _, predicted = torch.max(scores, 1)
        return self._DIRECTIONS[int(predicted[0])]
|