From 0ff13b8e2ff07ba7dbda37b6d5b0db85fe104f7d Mon Sep 17 00:00:00 2001 From: George Lacey Date: Thu, 12 Oct 2017 16:45:30 +0100 Subject: [PATCH] Run example instance - Save results to csv - Generate csv attributes --- src/file.py | 14 ++++++++++++-- src/lifecycle.py | 16 ++++++++++------ src/main.py | 19 +++++++++++++++---- src/population.py | 4 ++-- 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/src/file.py b/src/file.py index af445cf..9e61e98 100644 --- a/src/file.py +++ b/src/file.py @@ -2,5 +2,15 @@ import csv class File(object): - def __init__(self, filename): - pass + def __init__(self, filename, header): + self.file = open(filename, 'w') + self.writer = csv.DictWriter(self.file, fieldnames=header) + + def write_header(self): + self.writer.writeheader() + + def write_row(self, row): + self.writer.writerow(row) + + def close(self): + self.file.close() diff --git a/src/lifecycle.py b/src/lifecycle.py index 9b20149..0f71a25 100644 --- a/src/lifecycle.py +++ b/src/lifecycle.py @@ -12,15 +12,14 @@ class Lifecycle(object): def start(self): for epoch in range(0, self.params["iter"]): - elite = round(self.params["elite"] * self.params["population_size"]) - crossover = round((self.params["crossover"] - * self.params["population_size"]) / 2) + elite = self.params["elite"] + crossover = self.params["crossover"] self.population.advance_generation(elite, crossover) self.best_fit.append(self.population.best_fitness()) self.average_fit.append(self.population.avg_fitness()) - def best_member(self): - return self.population.best_member() + def best_fitness(self): + return self.population.best_fitness() def generate_graph(self, show=False, location=None): plt.plot(self.best_fit) @@ -29,4 +28,9 @@ class Lifecycle(object): if show: plt.show() if location is not None: - plt.savefig(location) \ No newline at end of file + plt.savefig(location) + + def get_csv(self): + id = {'id': self.id} + best_fitness = {'best_fit': self.best_fitness()} + return {**id, **best_fitness, **self.params} diff --git a/src/main.py b/src/main.py index ffb3f09..7cf3bd2 100644 --- a/src/main.py +++ b/src/main.py @@ -1,16 +1,27 @@ import os from lifecycle import Lifecycle +from file import File + +param_example = {'population_size': 10, + 'elite': round(0.1 * 10), + 'crossover': round(0.6 * 10), + 'iter': 1000} def cls(): os.system('cls' if os.name == 'nt' else 'clear') -ga = Lifecycle(1, {'population_size': 10, - 'elite': 0.1, - 'crossover': 0.6, - 'iter': 1000}) +output_file = File("output.csv", ['id', 'best_fit'] + list(param_example)) + +ga = Lifecycle(1, param_example) ga.start() +output_file.write_header() + +output_file.write_row(ga.get_csv()) + +output_file.close() + ga.generate_graph(show=True) diff --git a/src/population.py b/src/population.py index 53021d8..3d3bc22 100644 --- a/src/population.py +++ b/src/population.py @@ -68,7 +68,7 @@ class Population(object): def elite(self, amount): sorted_members = sorted(self.members, key=lambda - x: x.fitness(), reverse=True)[:amount] + x: x.fitness(), reverse=True)[:amount] return sorted_members @@ -92,7 +92,7 @@ class Population(object): self.members = new_generation - self.mutate(20) + self.mutate(10) def remove_member(self, member): self.members.remove(member)