Run example instance

- Save results to csv
- Generate csv attributes
This commit is contained in:
George Lacey 2017-10-12 16:45:30 +01:00
parent 597f289e81
commit 0ff13b8e2f
4 changed files with 39 additions and 14 deletions

View File

@ -2,5 +2,15 @@ import csv
class File(object): class File(object):
def __init__(self, filename): def __init__(self, filename, header):
pass self.file = open(filename, 'w')
self.writer = csv.DictWriter(self.file, fieldnames=header)
def write_header(self):
self.writer.writeheader()
def write_row(self, row):
self.writer.writerow(row)
def close(self):
self.file.close()

View File

@ -12,15 +12,14 @@ class Lifecycle(object):
def start(self): def start(self):
for epoch in range(0, self.params["iter"]): for epoch in range(0, self.params["iter"]):
elite = round(self.params["elite"] * self.params["population_size"]) elite = self.params["elite"]
crossover = round((self.params["crossover"] crossover = self.params["crossover"]
* self.params["population_size"]) / 2)
self.population.advance_generation(elite, crossover) self.population.advance_generation(elite, crossover)
self.best_fit.append(self.population.best_fitness()) self.best_fit.append(self.population.best_fitness())
self.average_fit.append(self.population.avg_fitness()) self.average_fit.append(self.population.avg_fitness())
def best_member(self): def best_fitness(self):
return self.population.best_member() return self.population.best_fitness()
def generate_graph(self, show=False, location=None): def generate_graph(self, show=False, location=None):
plt.plot(self.best_fit) plt.plot(self.best_fit)
@ -30,3 +29,8 @@ class Lifecycle(object):
plt.show() plt.show()
if location is not None: if location is not None:
plt.savefig(location) plt.savefig(location)
def get_csv(self):
id = {'id': self.id}
best_fitness = {'best_fit': self.best_fitness()}
return {**id, **best_fitness, **self.params}

View File

@ -1,16 +1,27 @@
import os import os
from lifecycle import Lifecycle from lifecycle import Lifecycle
from file import File
param_example = {'population_size': 10,
'elite': round(0.1 * 10),
'crossover': round(0.6 * 10),
'iter': 1000}
def cls(): def cls():
os.system('cls' if os.name == 'nt' else 'clear') os.system('cls' if os.name == 'nt' else 'clear')
ga = Lifecycle(1, {'population_size': 10, output_file = File("output.csv", ['id', 'best_fit'] + list(param_example))
'elite': 0.1,
'crossover': 0.6, ga = Lifecycle(1, param_example)
'iter': 1000})
ga.start() ga.start()
output_file.write_header()
output_file.write_row(ga.get_csv())
output_file.close()
ga.generate_graph(show=True) ga.generate_graph(show=True)

View File

@ -68,7 +68,7 @@ class Population(object):
def elite(self, amount): def elite(self, amount):
sorted_members = sorted(self.members, key=lambda sorted_members = sorted(self.members, key=lambda
x: x.fitness(), reverse=True)[:amount] x: x.fitness(), reverse=True)[:amount]
return sorted_members return sorted_members
@ -92,7 +92,7 @@ class Population(object):
self.members = new_generation self.members = new_generation
self.mutate(20) self.mutate(10)
def remove_member(self, member): def remove_member(self, member):
self.members.remove(member) self.members.remove(member)