How can I pause / serialize a genetic algorithm in Encog?

How do I pause the genetic algorithm in Encog 3.4 (the version currently being developed on GitHub)?

I am using the Java version of Encog.

I am trying to modify the Lunar example that comes with Encog. I want to pause / serialize the genetic algorithm, and then continue / deserialize at a later stage.

When I call train.pause(), it just returns null - which is clear from the source, since the method always returns null.

I would imagine this is a fairly common scenario: I train a neural network, use it for some predictions, and then continue training the genetic algorithm as more data comes in, before resuming predictions - all without having to restart training from the very beginning.

Please note that I am not trying to serialize or save the neural network, but rather the entire genetic algorithm.

1 answer

Not all Encog trainers support simple pause / resume. If they do not support it, they return null, as this one does. The genetic algorithm trainer is much more complex than a simple propagation trainer that supports pause / resume. To preserve the state of the genetic algorithm, you must save the entire population as well as the scoring function (which may or may not be serializable). I modified the Lunar Lander example to show you how you can save / reload your population of neural networks to do this.

You can see that it trains for 50 iterations, then round-trips (saves and reloads) the genetic algorithm, then trains for another 50.

    package org.encog.examples.neural.lunar;

    import java.io.File;
    import java.io.IOException;

    import org.encog.Encog;
    import org.encog.engine.network.activation.ActivationTANH;
    import org.encog.ml.MLMethod;
    import org.encog.ml.MLResettable;
    import org.encog.ml.MethodFactory;
    import org.encog.ml.ea.population.Population;
    import org.encog.ml.genetic.MLMethodGeneticAlgorithm;
    import org.encog.ml.genetic.MLMethodGenomeFactory;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.pattern.FeedForwardPattern;
    import org.encog.util.obj.SerializeObject;

    public class LunarLander {

        public static BasicNetwork createNetwork() {
            FeedForwardPattern pattern = new FeedForwardPattern();
            pattern.setInputNeurons(3);
            pattern.addHiddenLayer(50);
            pattern.setOutputNeurons(1);
            pattern.setActivationFunction(new ActivationTANH());
            BasicNetwork network = (BasicNetwork) pattern.generate();
            network.reset();
            return network;
        }

        public static void saveMLMethodGeneticAlgorithm(String file, MLMethodGeneticAlgorithm ga)
                throws IOException {
            // The genome factory holds a non-serializable anonymous class,
            // so clear it before saving; it is recreated on load.
            ga.getGenetic().getPopulation().setGenomeFactory(null);
            SerializeObject.save(new File(file), ga.getGenetic().getPopulation());
        }

        public static MLMethodGeneticAlgorithm loadMLMethodGeneticAlgorithm(String filename)
                throws ClassNotFoundException, IOException {
            Population pop = (Population) SerializeObject.load(new File(filename));
            // Restore the genome factory that was cleared before saving.
            pop.setGenomeFactory(new MLMethodGenomeFactory(new MethodFactory() {
                @Override
                public MLMethod factor() {
                    final BasicNetwork result = createNetwork();
                    ((MLResettable) result).reset();
                    return result;
                }
            }, pop));

            MLMethodGeneticAlgorithm result = new MLMethodGeneticAlgorithm(new MethodFactory() {
                @Override
                public MLMethod factor() {
                    return createNetwork();
                }
            }, new PilotScore(), 1);
            result.getGenetic().setPopulation(pop);
            return result;
        }

        public static void main(String args[]) {
            BasicNetwork network = createNetwork();

            MLMethodGeneticAlgorithm train = new MLMethodGeneticAlgorithm(new MethodFactory() {
                @Override
                public MLMethod factor() {
                    final BasicNetwork result = createNetwork();
                    ((MLResettable) result).reset();
                    return result;
                }
            }, new PilotScore(), 500);

            try {
                int epoch = 1;

                for (int i = 0; i < 50; i++) {
                    train.iteration();
                    System.out.println("Epoch #" + epoch + " Score:" + train.getError());
                    epoch++;
                }
                train.finishTraining();

                // Round trip the GA and then train again
                LunarLander.saveMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin", train);
                train = LunarLander.loadMLMethodGeneticAlgorithm("/Users/jeff/projects/trainer.bin");

                // Train again
                for (int i = 0; i < 50; i++) {
                    train.iteration();
                    System.out.println("Epoch #" + epoch + " Score:" + train.getError());
                    epoch++;
                }
                train.finishTraining();
            } catch (IOException ex) {
                ex.printStackTrace();
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }

            System.out.println("\nHow the winning network landed:");
            network = (BasicNetwork) train.getMethod();
            NeuralPilot pilot = new NeuralPilot(network, true);
            System.out.println(pilot.scorePilot());
            Encog.getInstance().shutdown();
        }
    }
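Encog's SerializeObject helper is, as far as I can tell, a thin wrapper around standard Java object serialization, so the usual caveat applies: everything reachable from the population, including a custom scoring function such as PilotScore, must implement Serializable, or the save will fail with a NotSerializableException. A minimal stand-alone sketch of that round trip (the Score class here is a hypothetical stand-in, not an Encog type):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

public class SerializationSketch {
    // Hypothetical scoring function; saving works only because it
    // implements Serializable. Drop "implements Serializable" and
    // writeObject() throws NotSerializableException instead.
    static class Score implements Serializable {
        private static final long serialVersionUID = 1L;
        double best = 42.0;
    }

    public static void main(String[] args) throws Exception {
        // Serialize to an in-memory buffer (a file works the same way).
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(buf)) {
            out.writeObject(new Score());
        }
        // Deserialize and confirm the state survived the round trip.
        try (ObjectInputStream in = new ObjectInputStream(
                new ByteArrayInputStream(buf.toByteArray()))) {
            Score restored = (Score) in.readObject();
            System.out.println(restored.best);
        }
    }
}
```

This is also why the example nulls out the genome factory before saving: the factory is an anonymous class holding non-serializable state, so it has to be rebuilt after loading rather than serialized with the population.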
