Tidy code

- Don't mutate the elite
- Remove unused methods
- Remove roulette selection
This commit is contained in:
George Lacey 2017-10-16 16:19:21 +01:00
parent 300678156c
commit 17cf9a67db
3 changed files with 38 additions and 36 deletions

View File

@ -16,7 +16,9 @@ class Lifecycle(object):
self.iterations += 1 self.iterations += 1
elite = self.params["elite"] elite = self.params["elite"]
crossover = self.params["crossover"] crossover = self.params["crossover"]
self.population.advance_generation(elite, crossover_rate=crossover, n_arena=4) mutation = self.params["mutation"]
self.population.advance_generation(elite, crossover_rate=crossover,
n_arena=4, mutation=mutation)
self.best_fit.append(self.population.best_fitness()) self.best_fit.append(self.population.best_fitness())
self.average_fit.append(self.population.avg_fitness()) self.average_fit.append(self.population.avg_fitness())
if epoch > 50: if epoch > 50:
@ -24,7 +26,6 @@ class Lifecycle(object):
if max(recent_best) - min(recent_best) < 0.02: if max(recent_best) - min(recent_best) < 0.02:
break break
def best_fitness(self): def best_fitness(self):
return self.population.best_fitness() return self.population.best_fitness()
@ -32,11 +33,19 @@ class Lifecycle(object):
plt.plot(self.best_fit) plt.plot(self.best_fit)
plt.xlabel("Epoch") plt.xlabel("Epoch")
plt.ylabel("Best fitness") plt.ylabel("Best fitness")
plt.title(self.get_params())
if show: if show:
plt.show() plt.show()
if location is not None: if location is not None:
plt.savefig(location) plt.savefig(location)
def get_params(self):
return_string = ("mutation: %d " % self.params['mutation'])
return_string += ("crossover: %d " % self.params['crossover'])
return_string += ("elite: %d " % self.params['elite'])
return_string += ("iter: %d" % self.iterations)
return return_string
def get_csv(self): def get_csv(self):
id = {'id': self.id} id = {'id': self.id}
best_fitness = {'best_fit': self.best_fitness()} best_fitness = {'best_fit': self.best_fitness()}

View File

@ -1,4 +1,3 @@
import os
from multiprocessing import Process, Queue from multiprocessing import Process, Queue
from lifecycle import Lifecycle from lifecycle import Lifecycle
from file import File from file import File
@ -8,14 +7,16 @@ import time
param_example = {'population_size': 10, param_example = {'population_size': 10,
'elite': round(0.1 * 10), 'elite': round(0.1 * 10),
'crossover': round(0.6 * 10), 'crossover': round(0.6 * 10),
'mutation': 0,
'epochs': 1000, 'epochs': 1000,
'iter': 0} 'iter': 0}
def gen_param(pop, elite, crossover, epochs): def gen_param(pop, elite, crossover, epochs, mutation=0):
return {'population_size': pop, return {'population_size': pop,
'elite': round(elite * pop), 'elite': round((elite / 100) * pop),
'crossover': round(crossover * pop), 'crossover': round((crossover / 100) * pop),
'mutation': mutation,
'epochs': epochs} 'epochs': epochs}
@ -25,25 +26,23 @@ outputs = Queue()
def run_instance(instance, output_list): def run_instance(instance, output_list):
instance.start() instance.start()
output_list.put(instance.get_csv()) output_list.put(instance.get_csv())
instance.generate_graph(location="images/%d.png" % instance.id, show=False)
print(instance.id) print(instance.id)
def cls():
os.system('cls' if os.name == 'nt' else 'clear')
output_file = File("output.csv", ['id', 'best_fit'] + list(param_example)) output_file = File("output.csv", ['id', 'best_fit'] + list(param_example))
output_file.write_header() output_file.write_header()
instances = list() instances = list()
instance_id = 0
id = 0 for i in range(20, 90, 10):
for j in range(0, 3):
for i in [x * 1 for x in range(0, 100)]: instances.append(
instances.append(Process(name="Thread-%d" % i, target=run_instance, Process(name="Thread-%d" % i, target=run_instance,
args=[Lifecycle(id, gen_param(50, 0, 0.5, 500)), outputs])) args=[Lifecycle(instance_id,
id += 1 gen_param(50, j, i, 500, mutation=5)), outputs]))
instance_id += 1
start_time = time.time() start_time = time.time()

View File

@ -72,38 +72,32 @@ class Population(object):
return sorted_members return sorted_members
def advance_generation(self, n_elite, crossover_rate=0.5, n_arena=4): def advance_generation(self, n_elite, crossover_rate=0.5, n_arena=4, mutation=0):
new_generation = list() new_generation = list()
# elitism while len(new_generation) < len(self.members) - n_elite:
for member in self.elite(n_elite):
new_generation.append(member)
# parent
# for ind in range(0, n_crossover):
# self.roulette_crossover(new_generation)
while len(new_generation) < len(self.members):
if n_arena > 0: if n_arena > 0:
x, y = self.tournament_selection(n_arena, crossover_rate) x, y = self.tournament_selection(n_arena, crossover_rate)
else:
x, y = self.roulette_crossover()
new_generation.append(x) new_generation.append(x)
new_generation.append(y) new_generation.append(y)
self.members = new_generation # Remove excess members
while len(new_generation) > len(self.members) - n_elite:
new_generation.pop()
self.mutate(10) for member in new_generation:
if rand.random() < mutation/100:
member.mutate()
# elitism
new_generation += self.elite(n_elite)
self.members = new_generation
def remove_member(self, member): def remove_member(self, member):
self.members.remove(member) self.members.remove(member)
def roulette_crossover(self):
parent_one = self.roulette()
parent_two = self.roulette()
return parent_one.crossover(parent_two)
def tournament_selection(self, arena_size, rate): def tournament_selection(self, arena_size, rate):
parents = list() parents = list()
for i in range(arena_size): for i in range(arena_size):