- IA simple basée sur un perceptron
This commit is contained in:
parent
33e475bc6c
commit
541d71e0a5
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,55 @@
|
|||
from tron.player import Player, Direction
|
||||
from tron.game import Tile
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
|
||||
class Net(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.fc1 = nn.Linear(144, 4)
|
||||
|
||||
def forward(self, x):
|
||||
n_images = x.shape[0]
|
||||
x = x.view(n_images, -1)
|
||||
return self.fc1(x)
|
||||
|
||||
|
||||
class Ai(Player):
|
||||
"""
|
||||
This class implements an AI based on the perceptron defined in class Net
|
||||
"""
|
||||
def __init__(self):
|
||||
super(Ai, self).__init__()
|
||||
self.net = Net()
|
||||
# Load network weights if they have been initialized already
|
||||
exists = os.path.isfile(self.find_file('ai.bak'))
|
||||
if exists:
|
||||
self.net.load_state_dict(torch.load(self.find_file('ai.bak')))
|
||||
|
||||
def action(self, map, id):
|
||||
|
||||
game_map = map.state_for_player(id)
|
||||
|
||||
input = np.reshape(game_map, (1, 1, game_map.shape[0], game_map.shape[1]))
|
||||
input = torch.from_numpy(input).float()
|
||||
output = self.net(input)
|
||||
|
||||
_, predicted = torch.max(output.data, 1)
|
||||
predicted = predicted.numpy()
|
||||
next_action = predicted[0] + 1
|
||||
|
||||
if next_action == 1:
|
||||
next_direction = Direction.UP
|
||||
if next_action == 2:
|
||||
next_direction = Direction.RIGHT
|
||||
if next_action == 3:
|
||||
next_direction = Direction.DOWN
|
||||
if next_action == 4:
|
||||
next_direction = Direction.LEFT
|
||||
|
||||
return next_direction
|
Loading…
Reference in New Issue