SHOW:
|
|
- or go back to the newest paste.
1 | -- MarI/O by SethBling | |
2 | -- Feel free to use this code, but please do not redistribute it. | |
3 | -- Intended for use with the BizHawk emulator and Super Mario World or Super Mario Bros. ROM. | |
4 | -- For SMW, make sure you have a save state named "DP1.state" at the beginning of a level, | |
5 | -- and put a copy in both the Lua folder and the root directory of BizHawk. | |
6 | ||
7 | if gameinfo.getromname() == "Super Mario World (USA)" then | |
8 | Filename = "DP1.state" | |
9 | ButtonNames = { | |
10 | "A", | |
11 | "B", | |
12 | "X", | |
13 | "Y", | |
14 | "Up", | |
15 | "Down", | |
16 | "Left", | |
17 | "Right", | |
18 | } | |
19 | elseif gameinfo.getromname() == "Super Mario Bros." then | |
20 | Filename = "SMB1-1.state" | |
21 | ButtonNames = { | |
22 | "A", | |
23 | "B", | |
24 | "Up", | |
25 | "Down", | |
26 | "Left", | |
27 | "Right", | |
28 | } | |
29 | end | |
30 | ||
31 | BoxRadius = 6 | |
32 | InputSize = (BoxRadius*2+1)*(BoxRadius*2+1) | |
33 | ||
34 | Inputs = InputSize+1 | |
35 | Outputs = #ButtonNames | |
36 | ||
37 | Population = 300 | |
38 | DeltaDisjoint = 2.0 | |
39 | DeltaWeights = 0.4 | |
40 | DeltaThreshold = 1.0 | |
41 | ||
42 | StaleSpecies = 15 | |
43 | ||
44 | MutateConnectionsChance = 0.25 | |
45 | PerturbChance = 0.90 | |
46 | CrossoverChance = 0.75 | |
47 | LinkMutationChance = 2.0 | |
48 | NodeMutationChance = 0.50 | |
49 | BiasMutationChance = 0.40 | |
50 | StepSize = 0.1 | |
51 | DisableMutationChance = 0.4 | |
52 | EnableMutationChance = 0.2 | |
53 | ||
54 | TimeoutConstant = 20 | |
55 | ||
56 | MaxNodes = 1000000 | |
57 | ||
58 | function getPositions() | |
59 | if gameinfo.getromname() == "Super Mario World (USA)" then | |
60 | marioX = memory.read_s16_le(0x94) | |
61 | marioY = memory.read_s16_le(0x96) | |
62 | ||
63 | local layer1x = memory.read_s16_le(0x1A); | |
64 | local layer1y = memory.read_s16_le(0x1C); | |
65 | ||
66 | screenX = marioX-layer1x | |
67 | screenY = marioY-layer1y | |
68 | elseif gameinfo.getromname() == "Super Mario Bros." then | |
69 | marioX = memory.readbyte(0x6D) * 0x100 + memory.readbyte(0x86) | |
70 | marioY = memory.readbyte(0x03B8)+16 | |
71 | ||
72 | screenX = memory.readbyte(0x03AD) | |
73 | screenY = memory.readbyte(0x03B8) | |
74 | end | |
75 | end | |
76 | ||
77 | function getTile(dx, dy) | |
78 | if gameinfo.getromname() == "Super Mario World (USA)" then | |
79 | x = math.floor((marioX+dx+8)/16) | |
80 | y = math.floor((marioY+dy)/16) | |
81 | ||
82 | return memory.readbyte(0x1C800 + math.floor(x/0x10)*0x1B0 + y*0x10 + x%0x10) | |
83 | elseif gameinfo.getromname() == "Super Mario Bros." then | |
84 | local x = marioX + dx + 8 | |
85 | local y = marioY + dy - 16 | |
86 | local page = math.floor(x/256)%2 | |
87 | ||
88 | local subx = math.floor((x%256)/16) | |
89 | local suby = math.floor((y - 32)/16) | |
90 | local addr = 0x500 + page*13*16+suby*16+subx | |
91 | ||
92 | if suby >= 13 or suby < 0 then | |
93 | return 0 | |
94 | end | |
95 | ||
96 | if memory.readbyte(addr) ~= 0 then | |
97 | return 1 | |
98 | else | |
99 | return 0 | |
100 | end | |
101 | end | |
102 | end | |
103 | ||
104 | function getSprites() | |
105 | if gameinfo.getromname() == "Super Mario World (USA)" then | |
106 | local sprites = {} | |
107 | for slot=0,11 do | |
108 | local status = memory.readbyte(0x14C8+slot) | |
109 | if status ~= 0 then | |
110 | spritex = memory.readbyte(0xE4+slot) + memory.readbyte(0x14E0+slot)*256 | |
111 | spritey = memory.readbyte(0xD8+slot) + memory.readbyte(0x14D4+slot)*256 | |
112 | sprites[#sprites+1] = {["x"]=spritex, ["y"]=spritey} | |
113 | end | |
114 | end | |
115 | ||
116 | return sprites | |
117 | elseif gameinfo.getromname() == "Super Mario Bros." then | |
118 | local sprites = {} | |
119 | for slot=0,4 do | |
120 | local enemy = memory.readbyte(0xF+slot) | |
121 | if enemy ~= 0 then | |
122 | local ex = memory.readbyte(0x6E + slot)*0x100 + memory.readbyte(0x87+slot) | |
123 | local ey = memory.readbyte(0xCF + slot)+24 | |
124 | sprites[#sprites+1] = {["x"]=ex,["y"]=ey} | |
125 | end | |
126 | end | |
127 | ||
128 | return sprites | |
129 | end | |
130 | end | |
131 | ||
132 | function getExtendedSprites() | |
133 | if gameinfo.getromname() == "Super Mario World (USA)" then | |
134 | local extended = {} | |
135 | for slot=0,11 do | |
136 | local number = memory.readbyte(0x170B+slot) | |
137 | if number ~= 0 then | |
138 | spritex = memory.readbyte(0x171F+slot) + memory.readbyte(0x1733+slot)*256 | |
139 | spritey = memory.readbyte(0x1715+slot) + memory.readbyte(0x1729+slot)*256 | |
140 | extended[#extended+1] = {["x"]=spritex, ["y"]=spritey} | |
141 | end | |
142 | end | |
143 | ||
144 | return extended | |
145 | elseif gameinfo.getromname() == "Super Mario Bros." then | |
146 | return {} | |
147 | end | |
148 | end | |
149 | ||
150 | function getInputs() | |
151 | getPositions() | |
152 | ||
153 | sprites = getSprites() | |
154 | extended = getExtendedSprites() | |
155 | ||
156 | local inputs = {} | |
157 | ||
158 | for dy=-BoxRadius*16,BoxRadius*16,16 do | |
159 | for dx=-BoxRadius*16,BoxRadius*16,16 do | |
160 | inputs[#inputs+1] = 0 | |
161 | ||
162 | tile = getTile(dx, dy) | |
163 | if tile == 1 and marioY+dy < 0x1B0 then | |
164 | inputs[#inputs] = 1 | |
165 | end | |
166 | ||
167 | for i = 1,#sprites do | |
168 | distx = math.abs(sprites[i]["x"] - (marioX+dx)) | |
169 | disty = math.abs(sprites[i]["y"] - (marioY+dy)) | |
170 | if distx <= 8 and disty <= 8 then | |
171 | inputs[#inputs] = -1 | |
172 | end | |
173 | end | |
174 | ||
175 | for i = 1,#extended do | |
176 | distx = math.abs(extended[i]["x"] - (marioX+dx)) | |
177 | disty = math.abs(extended[i]["y"] - (marioY+dy)) | |
178 | if distx < 8 and disty < 8 then | |
179 | inputs[#inputs] = -1 | |
180 | end | |
181 | end | |
182 | end | |
183 | end | |
184 | ||
185 | --mariovx = memory.read_s8(0x7B) | |
186 | --mariovy = memory.read_s8(0x7D) | |
187 | ||
188 | return inputs | |
189 | end | |
190 | ||
191 | function sigmoid(x) | |
192 | return 2/(1+math.exp(-4.9*x))-1 | |
193 | end | |
194 | ||
195 | function newInnovation() | |
196 | pool.innovation = pool.innovation + 1 | |
197 | return pool.innovation | |
198 | end | |
199 | ||
200 | function newPool() | |
201 | local pool = {} | |
202 | pool.species = {} | |
203 | pool.generation = 0 | |
204 | pool.innovation = Outputs | |
205 | pool.currentSpecies = 1 | |
206 | pool.currentGenome = 1 | |
207 | pool.currentFrame = 0 | |
208 | pool.maxFitness = 0 | |
209 | ||
210 | return pool | |
211 | end | |
212 | ||
213 | function newSpecies() | |
214 | local species = {} | |
215 | species.topFitness = 0 | |
216 | species.staleness = 0 | |
217 | species.genomes = {} | |
218 | species.averageFitness = 0 | |
219 | ||
220 | return species | |
221 | end | |
222 | ||
223 | function newGenome() | |
224 | local genome = {} | |
225 | genome.genes = {} | |
226 | genome.fitness = 0 | |
227 | genome.adjustedFitness = 0 | |
228 | genome.network = {} | |
229 | genome.maxneuron = 0 | |
230 | genome.globalRank = 0 | |
231 | genome.mutationRates = {} | |
232 | genome.mutationRates["connections"] = MutateConnectionsChance | |
233 | genome.mutationRates["link"] = LinkMutationChance | |
234 | genome.mutationRates["bias"] = BiasMutationChance | |
235 | genome.mutationRates["node"] = NodeMutationChance | |
236 | genome.mutationRates["enable"] = EnableMutationChance | |
237 | genome.mutationRates["disable"] = DisableMutationChance | |
238 | genome.mutationRates["step"] = StepSize | |
239 | ||
240 | return genome | |
241 | end | |
242 | ||
243 | function copyGenome(genome) | |
244 | local genome2 = newGenome() | |
245 | for g=1,#genome.genes do | |
246 | table.insert(genome2.genes, copyGene(genome.genes[g])) | |
247 | end | |
248 | genome2.maxneuron = genome.maxneuron | |
249 | genome2.mutationRates["connections"] = genome.mutationRates["connections"] | |
250 | genome2.mutationRates["link"] = genome.mutationRates["link"] | |
251 | genome2.mutationRates["bias"] = genome.mutationRates["bias"] | |
252 | genome2.mutationRates["node"] = genome.mutationRates["node"] | |
253 | genome2.mutationRates["enable"] = genome.mutationRates["enable"] | |
254 | genome2.mutationRates["disable"] = genome.mutationRates["disable"] | |
255 | ||
256 | return genome2 | |
257 | end | |
258 | ||
259 | function basicGenome() | |
260 | local genome = newGenome() | |
261 | local innovation = 1 | |
262 | ||
263 | genome.maxneuron = Inputs | |
264 | mutate(genome) | |
265 | ||
266 | return genome | |
267 | end | |
268 | ||
269 | function newGene() | |
270 | local gene = {} | |
271 | gene.into = 0 | |
272 | gene.out = 0 | |
273 | gene.weight = 0.0 | |
274 | gene.enabled = true | |
275 | gene.innovation = 0 | |
276 | ||
277 | return gene | |
278 | end | |
279 | ||
280 | function copyGene(gene) | |
281 | local gene2 = newGene() | |
282 | gene2.into = gene.into | |
283 | gene2.out = gene.out | |
284 | gene2.weight = gene.weight | |
285 | gene2.enabled = gene.enabled | |
286 | gene2.innovation = gene.innovation | |
287 | ||
288 | return gene2 | |
289 | end | |
290 | ||
291 | function newNeuron() | |
292 | local neuron = {} | |
293 | neuron.incoming = {} | |
294 | neuron.value = 0.0 | |
295 | ||
296 | return neuron | |
297 | end | |
298 | ||
299 | function generateNetwork(genome) | |
300 | local network = {} | |
301 | network.neurons = {} | |
302 | ||
303 | for i=1,Inputs do | |
304 | network.neurons[i] = newNeuron() | |
305 | end | |
306 | ||
307 | for o=1,Outputs do | |
308 | network.neurons[MaxNodes+o] = newNeuron() | |
309 | end | |
310 | ||
311 | table.sort(genome.genes, function (a,b) | |
312 | return (a.out < b.out) | |
313 | end) | |
314 | for i=1,#genome.genes do | |
315 | local gene = genome.genes[i] | |
316 | if gene.enabled then | |
317 | if network.neurons[gene.out] == nil then | |
318 | network.neurons[gene.out] = newNeuron() | |
319 | end | |
320 | local neuron = network.neurons[gene.out] | |
321 | table.insert(neuron.incoming, gene) | |
322 | if network.neurons[gene.into] == nil then | |
323 | network.neurons[gene.into] = newNeuron() | |
324 | end | |
325 | end | |
326 | end | |
327 | ||
328 | genome.network = network | |
329 | end | |
330 | ||
331 | function evaluateNetwork(network, inputs) | |
332 | table.insert(inputs, 1) | |
333 | if #inputs ~= Inputs then | |
334 | console.writeline("Incorrect number of neural network inputs.") | |
335 | return {} | |
336 | end | |
337 | ||
338 | for i=1,Inputs do | |
339 | network.neurons[i].value = inputs[i] | |
340 | end | |
341 | ||
342 | for _,neuron in pairs(network.neurons) do | |
343 | local sum = 0 | |
344 | for j = 1,#neuron.incoming do | |
345 | local incoming = neuron.incoming[j] | |
346 | local other = network.neurons[incoming.into] | |
347 | sum = sum + incoming.weight * other.value | |
348 | end | |
349 | ||
350 | if #neuron.incoming > 0 then | |
351 | neuron.value = sigmoid(sum) | |
352 | end | |
353 | end | |
354 | ||
355 | local outputs = {} | |
356 | for o=1,Outputs do | |
357 | local button = "P1 " .. ButtonNames[o] | |
358 | if network.neurons[MaxNodes+o].value > 0 then | |
359 | outputs[button] = true | |
360 | else | |
361 | outputs[button] = false | |
362 | end | |
363 | end | |
364 | ||
365 | return outputs | |
366 | end | |
367 | ||
368 | function crossover(g1, g2) | |
369 | -- Make sure g1 is the higher fitness genome | |
370 | if g2.fitness > g1.fitness then | |
371 | tempg = g1 | |
372 | g1 = g2 | |
373 | g2 = tempg | |
374 | end | |
375 | ||
376 | local child = newGenome() | |
377 | ||
378 | local innovations2 = {} | |
379 | for i=1,#g2.genes do | |
380 | local gene = g2.genes[i] | |
381 | innovations2[gene.innovation] = gene | |
382 | end | |
383 | ||
384 | for i=1,#g1.genes do | |
385 | local gene1 = g1.genes[i] | |
386 | local gene2 = innovations2[gene1.innovation] | |
387 | if gene2 ~= nil and math.random(2) == 1 and gene2.enabled then | |
388 | table.insert(child.genes, copyGene(gene2)) | |
389 | else | |
390 | table.insert(child.genes, copyGene(gene1)) | |
391 | end | |
392 | end | |
393 | ||
394 | child.maxneuron = math.max(g1.maxneuron,g2.maxneuron) | |
395 | ||
396 | for mutation,rate in pairs(g1.mutationRates) do | |
397 | child.mutationRates[mutation] = rate | |
398 | end | |
399 | ||
400 | return child | |
401 | end | |
402 | ||
403 | function randomNeuron(genes, nonInput) | |
404 | local neurons = {} | |
405 | if not nonInput then | |
406 | for i=1,Inputs do | |
407 | neurons[i] = true | |
408 | end | |
409 | end | |
410 | for o=1,Outputs do | |
411 | neurons[MaxNodes+o] = true | |
412 | end | |
413 | for i=1,#genes do | |
414 | if (not nonInput) or genes[i].into > Inputs then | |
415 | neurons[genes[i].into] = true | |
416 | end | |
417 | if (not nonInput) or genes[i].out > Inputs then | |
418 | neurons[genes[i].out] = true | |
419 | end | |
420 | end | |
421 | ||
422 | local count = 0 | |
423 | for _,_ in pairs(neurons) do | |
424 | count = count + 1 | |
425 | end | |
426 | local n = math.random(1, count) | |
427 | ||
428 | for k,v in pairs(neurons) do | |
429 | n = n-1 | |
430 | if n == 0 then | |
431 | return k | |
432 | end | |
433 | end | |
434 | ||
435 | return 0 | |
436 | end | |
437 | ||
438 | function containsLink(genes, link) | |
439 | for i=1,#genes do | |
440 | local gene = genes[i] | |
441 | if gene.into == link.into and gene.out == link.out then | |
442 | return true | |
443 | end | |
444 | end | |
445 | end | |
446 | ||
447 | function pointMutate(genome) | |
448 | local step = genome.mutationRates["step"] | |
449 | ||
450 | for i=1,#genome.genes do | |
451 | local gene = genome.genes[i] | |
452 | if math.random() < PerturbChance then | |
453 | gene.weight = gene.weight + math.random() * step*2 - step | |
454 | else | |
455 | gene.weight = math.random()*4-2 | |
456 | end | |
457 | end | |
458 | end | |
459 | ||
460 | function linkMutate(genome, forceBias) | |
461 | local neuron1 = randomNeuron(genome.genes, false) | |
462 | local neuron2 = randomNeuron(genome.genes, true) | |
463 | ||
464 | local newLink = newGene() | |
465 | if neuron1 <= Inputs and neuron2 <= Inputs then | |
466 | --Both input nodes | |
467 | return | |
468 | end | |
469 | if neuron2 <= Inputs then | |
470 | -- Swap output and input | |
471 | local temp = neuron1 | |
472 | neuron1 = neuron2 | |
473 | neuron2 = temp | |
474 | end | |
475 | ||
476 | newLink.into = neuron1 | |
477 | newLink.out = neuron2 | |
478 | if forceBias then | |
479 | newLink.into = Inputs | |
480 | end | |
481 | ||
482 | if containsLink(genome.genes, newLink) then | |
483 | return | |
484 | end | |
485 | newLink.innovation = newInnovation() | |
486 | newLink.weight = math.random()*4-2 | |
487 | ||
488 | table.insert(genome.genes, newLink) | |
489 | end | |
490 | ||
491 | function nodeMutate(genome) | |
492 | if #genome.genes == 0 then | |
493 | return | |
494 | end | |
495 | ||
496 | genome.maxneuron = genome.maxneuron + 1 | |
497 | ||
498 | local gene = genome.genes[math.random(1,#genome.genes)] | |
499 | if not gene.enabled then | |
500 | return | |
501 | end | |
502 | gene.enabled = false | |
503 | ||
504 | local gene1 = copyGene(gene) | |
505 | gene1.out = genome.maxneuron | |
506 | gene1.weight = 1.0 | |
507 | gene1.innovation = newInnovation() | |
508 | gene1.enabled = true | |
509 | table.insert(genome.genes, gene1) | |
510 | ||
511 | local gene2 = copyGene(gene) | |
512 | gene2.into = genome.maxneuron | |
513 | gene2.innovation = newInnovation() | |
514 | gene2.enabled = true | |
515 | table.insert(genome.genes, gene2) | |
516 | end | |
517 | ||
518 | function enableDisableMutate(genome, enable) | |
519 | local candidates = {} | |
520 | for _,gene in pairs(genome.genes) do | |
521 | if gene.enabled == not enable then | |
522 | table.insert(candidates, gene) | |
523 | end | |
524 | end | |
525 | ||
526 | if #candidates == 0 then | |
527 | return | |
528 | end | |
529 | ||
530 | local gene = candidates[math.random(1,#candidates)] | |
531 | gene.enabled = not gene.enabled | |
532 | end | |
533 | ||
534 | function mutate(genome) | |
535 | for mutation,rate in pairs(genome.mutationRates) do | |
536 | if math.random(1,2) == 1 then | |
537 | genome.mutationRates[mutation] = 0.95*rate | |
538 | else | |
539 | genome.mutationRates[mutation] = 1.05263*rate | |
540 | end | |
541 | end | |
542 | ||
543 | if math.random() < genome.mutationRates["connections"] then | |
544 | pointMutate(genome) | |
545 | end | |
546 | ||
547 | local p = genome.mutationRates["link"] | |
548 | while p > 0 do | |
549 | if math.random() < p then | |
550 | linkMutate(genome, false) | |
551 | end | |
552 | p = p - 1 | |
553 | end | |
554 | ||
555 | p = genome.mutationRates["bias"] | |
556 | while p > 0 do | |
557 | if math.random() < p then | |
558 | linkMutate(genome, true) | |
559 | end | |
560 | p = p - 1 | |
561 | end | |
562 | ||
563 | p = genome.mutationRates["node"] | |
564 | while p > 0 do | |
565 | if math.random() < p then | |
566 | nodeMutate(genome) | |
567 | end | |
568 | p = p - 1 | |
569 | end | |
570 | ||
571 | p = genome.mutationRates["enable"] | |
572 | while p > 0 do | |
573 | if math.random() < p then | |
574 | enableDisableMutate(genome, true) | |
575 | end | |
576 | p = p - 1 | |
577 | end | |
578 | ||
579 | p = genome.mutationRates["disable"] | |
580 | while p > 0 do | |
581 | if math.random() < p then | |
582 | enableDisableMutate(genome, false) | |
583 | end | |
584 | p = p - 1 | |
585 | end | |
586 | end | |
587 | ||
588 | function disjoint(genes1, genes2) | |
589 | local i1 = {} | |
590 | for i = 1,#genes1 do | |
591 | local gene = genes1[i] | |
592 | i1[gene.innovation] = true | |
593 | end | |
594 | ||
595 | local i2 = {} | |
596 | for i = 1,#genes2 do | |
597 | local gene = genes2[i] | |
598 | i2[gene.innovation] = true | |
599 | end | |
600 | ||
601 | local disjointGenes = 0 | |
602 | for i = 1,#genes1 do | |
603 | local gene = genes1[i] | |
604 | if not i2[gene.innovation] then | |
605 | disjointGenes = disjointGenes+1 | |
606 | end | |
607 | end | |
608 | ||
609 | for i = 1,#genes2 do | |
610 | local gene = genes2[i] | |
611 | if not i1[gene.innovation] then | |
612 | disjointGenes = disjointGenes+1 | |
613 | end | |
614 | end | |
615 | ||
616 | local n = math.max(#genes1, #genes2) | |
617 | ||
618 | return disjointGenes / n | |
619 | end | |
620 | ||
621 | function weights(genes1, genes2) | |
622 | local i2 = {} | |
623 | for i = 1,#genes2 do | |
624 | local gene = genes2[i] | |
625 | i2[gene.innovation] = gene | |
626 | end | |
627 | ||
628 | local sum = 0 | |
629 | local coincident = 0 | |
630 | for i = 1,#genes1 do | |
631 | local gene = genes1[i] | |
632 | if i2[gene.innovation] ~= nil then | |
633 | local gene2 = i2[gene.innovation] | |
634 | sum = sum + math.abs(gene.weight - gene2.weight) | |
635 | coincident = coincident + 1 | |
636 | end | |
637 | end | |
638 | ||
639 | return sum / coincident | |
640 | end | |
641 | ||
642 | function sameSpecies(genome1, genome2) | |
643 | local dd = DeltaDisjoint*disjoint(genome1.genes, genome2.genes) | |
644 | local dw = DeltaWeights*weights(genome1.genes, genome2.genes) | |
645 | return dd + dw < DeltaThreshold | |
646 | end | |
647 | ||
648 | function rankGlobally() | |
649 | local global = {} | |
650 | for s = 1,#pool.species do | |
651 | local species = pool.species[s] | |
652 | for g = 1,#species.genomes do | |
653 | table.insert(global, species.genomes[g]) | |
654 | end | |
655 | end | |
656 | table.sort(global, function (a,b) | |
657 | return (a.fitness < b.fitness) | |
658 | end) | |
659 | ||
660 | for g=1,#global do | |
661 | global[g].globalRank = g | |
662 | end | |
663 | end | |
664 | ||
665 | function calculateAverageFitness(species) | |
666 | local total = 0 | |
667 | ||
668 | for g=1,#species.genomes do | |
669 | local genome = species.genomes[g] | |
670 | total = total + genome.globalRank | |
671 | end | |
672 | ||
673 | species.averageFitness = total / #species.genomes | |
674 | end | |
675 | ||
676 | function totalAverageFitness() | |
677 | local total = 0 | |
678 | for s = 1,#pool.species do | |
679 | local species = pool.species[s] | |
680 | total = total + species.averageFitness | |
681 | end | |
682 | ||
683 | return total | |
684 | end | |
685 | ||
686 | function cullSpecies(cutToOne) | |
687 | for s = 1,#pool.species do | |
688 | local species = pool.species[s] | |
689 | ||
690 | table.sort(species.genomes, function (a,b) | |
691 | return (a.fitness > b.fitness) | |
692 | end) | |
693 | ||
694 | local remaining = math.ceil(#species.genomes/2) | |
695 | if cutToOne then | |
696 | remaining = 1 | |
697 | end | |
698 | while #species.genomes > remaining do | |
699 | table.remove(species.genomes) | |
700 | end | |
701 | end | |
702 | end | |
703 | ||
704 | function breedChild(species) | |
705 | local child = {} | |
706 | if math.random() < CrossoverChance then | |
707 | g1 = species.genomes[math.random(1, #species.genomes)] | |
708 | g2 = species.genomes[math.random(1, #species.genomes)] | |
709 | child = crossover(g1, g2) | |
710 | else | |
711 | g = species.genomes[math.random(1, #species.genomes)] | |
712 | child = copyGenome(g) | |
713 | end | |
714 | ||
715 | mutate(child) | |
716 | ||
717 | return child | |
718 | end | |
719 | ||
720 | function removeStaleSpecies() | |
721 | local survived = {} | |
722 | ||
723 | for s = 1,#pool.species do | |
724 | local species = pool.species[s] | |
725 | ||
726 | table.sort(species.genomes, function (a,b) | |
727 | return (a.fitness > b.fitness) | |
728 | end) | |
729 | ||
730 | if species.genomes[1].fitness > species.topFitness then | |
731 | species.topFitness = species.genomes[1].fitness | |
732 | species.staleness = 0 | |
733 | else | |
734 | species.staleness = species.staleness + 1 | |
735 | end | |
736 | if species.staleness < StaleSpecies or species.topFitness >= pool.maxFitness then | |
737 | table.insert(survived, species) | |
738 | end | |
739 | end | |
740 | ||
741 | pool.species = survived | |
742 | end | |
743 | ||
744 | function removeWeakSpecies() | |
745 | local survived = {} | |
746 | ||
747 | local sum = totalAverageFitness() | |
748 | for s = 1,#pool.species do | |
749 | local species = pool.species[s] | |
750 | breed = math.floor(species.averageFitness / sum * Population) | |
751 | if breed >= 1 then | |
752 | table.insert(survived, species) | |
753 | end | |
754 | end | |
755 | ||
756 | pool.species = survived | |
757 | end | |
758 | ||
759 | ||
760 | function addToSpecies(child) | |
761 | local foundSpecies = false | |
762 | for s=1,#pool.species do | |
763 | local species = pool.species[s] | |
764 | if not foundSpecies and sameSpecies(child, species.genomes[1]) then | |
765 | table.insert(species.genomes, child) | |
766 | foundSpecies = true | |
767 | end | |
768 | end | |
769 | ||
770 | if not foundSpecies then | |
771 | local childSpecies = newSpecies() | |
772 | table.insert(childSpecies.genomes, child) | |
773 | table.insert(pool.species, childSpecies) | |
774 | end | |
775 | end | |
776 | ||
777 | function newGeneration() | |
778 | cullSpecies(false) -- Cull the bottom half of each species | |
779 | rankGlobally() | |
780 | removeStaleSpecies() | |
781 | rankGlobally() | |
782 | for s = 1,#pool.species do | |
783 | local species = pool.species[s] | |
784 | calculateAverageFitness(species) | |
785 | end | |
786 | removeWeakSpecies() | |
787 | local sum = totalAverageFitness() | |
788 | local children = {} | |
789 | for s = 1,#pool.species do | |
790 | local species = pool.species[s] | |
791 | breed = math.floor(species.averageFitness / sum * Population) - 1 | |
792 | for i=1,breed do | |
793 | table.insert(children, breedChild(species)) | |
794 | end | |
795 | end | |
796 | cullSpecies(true) -- Cull all but the top member of each species | |
797 | while #children + #pool.species < Population do | |
798 | local species = pool.species[math.random(1, #pool.species)] | |
799 | table.insert(children, breedChild(species)) | |
800 | end | |
801 | for c=1,#children do | |
802 | local child = children[c] | |
803 | addToSpecies(child) | |
804 | end | |
805 | ||
806 | pool.generation = pool.generation + 1 | |
807 | ||
808 | writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile)) | |
809 | end | |
810 | ||
811 | function initializePool() | |
812 | pool = newPool() | |
813 | ||
814 | for i=1,Population do | |
815 | basic = basicGenome() | |
816 | addToSpecies(basic) | |
817 | end | |
818 | ||
819 | initializeRun() | |
820 | end | |
821 | ||
822 | function clearJoypad() | |
823 | controller = {} | |
824 | for b = 1,#ButtonNames do | |
825 | controller["P1 " .. ButtonNames[b]] = false | |
826 | end | |
827 | joypad.set(controller) | |
828 | end | |
829 | ||
830 | function initializeRun() | |
831 | savestate.load(Filename); | |
832 | rightmost = 0 | |
833 | pool.currentFrame = 0 | |
834 | timeout = TimeoutConstant | |
835 | clearJoypad() | |
836 | ||
837 | local species = pool.species[pool.currentSpecies] | |
838 | local genome = species.genomes[pool.currentGenome] | |
839 | generateNetwork(genome) | |
840 | evaluateCurrent() | |
841 | end | |
842 | ||
843 | function evaluateCurrent() | |
844 | local species = pool.species[pool.currentSpecies] | |
845 | local genome = species.genomes[pool.currentGenome] | |
846 | ||
847 | inputs = getInputs() | |
848 | controller = evaluateNetwork(genome.network, inputs) | |
849 | ||
850 | if controller["P1 Left"] and controller["P1 Right"] then | |
851 | controller["P1 Left"] = false | |
852 | controller["P1 Right"] = false | |
853 | end | |
854 | if controller["P1 Up"] and controller["P1 Down"] then | |
855 | controller["P1 Up"] = false | |
856 | controller["P1 Down"] = false | |
857 | end | |
858 | ||
859 | joypad.set(controller) | |
860 | end | |
861 | ||
862 | if pool == nil then | |
863 | initializePool() | |
864 | end | |
865 | ||
866 | ||
867 | function nextGenome() | |
868 | pool.currentGenome = pool.currentGenome + 1 | |
869 | if pool.currentGenome > #pool.species[pool.currentSpecies].genomes then | |
870 | pool.currentGenome = 1 | |
871 | pool.currentSpecies = pool.currentSpecies+1 | |
872 | if pool.currentSpecies > #pool.species then | |
873 | newGeneration() | |
874 | pool.currentSpecies = 1 | |
875 | end | |
876 | end | |
877 | end | |
878 | ||
879 | function fitnessAlreadyMeasured() | |
880 | local species = pool.species[pool.currentSpecies] | |
881 | local genome = species.genomes[pool.currentGenome] | |
882 | ||
883 | return genome.fitness ~= 0 | |
884 | end | |
885 | ||
886 | function displayGenome(genome) | |
887 | local network = genome.network | |
888 | local cells = {} | |
889 | local i = 1 | |
890 | local cell = {} | |
891 | for dy=-BoxRadius,BoxRadius do | |
892 | for dx=-BoxRadius,BoxRadius do | |
893 | cell = {} | |
894 | cell.x = 50+5*dx | |
895 | cell.y = 70+5*dy | |
896 | cell.value = network.neurons[i].value | |
897 | cells[i] = cell | |
898 | i = i + 1 | |
899 | end | |
900 | end | |
901 | local biasCell = {} | |
902 | biasCell.x = 80 | |
903 | biasCell.y = 110 | |
904 | biasCell.value = network.neurons[Inputs].value | |
905 | cells[Inputs] = biasCell | |
906 | ||
907 | for o = 1,Outputs do | |
908 | cell = {} | |
909 | cell.x = 220 | |
910 | cell.y = 30 + 8 * o | |
911 | cell.value = network.neurons[MaxNodes + o].value | |
912 | cells[MaxNodes+o] = cell | |
913 | local color | |
914 | if cell.value > 0 then | |
915 | color = 0xFF0000FF | |
916 | else | |
917 | color = 0xFF000000 | |
918 | end | |
919 | gui.drawText(223, 24+8*o, ButtonNames[o], color, 9) | |
920 | end | |
921 | ||
922 | for n,neuron in pairs(network.neurons) do | |
923 | cell = {} | |
924 | if n > Inputs and n <= MaxNodes then | |
925 | cell.x = 140 | |
926 | cell.y = 40 | |
927 | cell.value = neuron.value | |
928 | cells[n] = cell | |
929 | end | |
930 | end | |
931 | ||
932 | for n=1,4 do | |
933 | for _,gene in pairs(genome.genes) do | |
934 | if gene.enabled then | |
935 | local c1 = cells[gene.into] | |
936 | local c2 = cells[gene.out] | |
937 | if gene.into > Inputs and gene.into <= MaxNodes then | |
938 | c1.x = 0.75*c1.x + 0.25*c2.x | |
939 | if c1.x >= c2.x then | |
940 | c1.x = c1.x - 40 | |
941 | end | |
942 | if c1.x < 90 then | |
943 | c1.x = 90 | |
944 | end | |
945 | ||
946 | if c1.x > 220 then | |
947 | c1.x = 220 | |
948 | end | |
949 | c1.y = 0.75*c1.y + 0.25*c2.y | |
950 | ||
951 | end | |
952 | if gene.out > Inputs and gene.out <= MaxNodes then | |
953 | c2.x = 0.25*c1.x + 0.75*c2.x | |
954 | if c1.x >= c2.x then | |
955 | c2.x = c2.x + 40 | |
956 | end | |
957 | if c2.x < 90 then | |
958 | c2.x = 90 | |
959 | end | |
960 | if c2.x > 220 then | |
961 | c2.x = 220 | |
962 | end | |
963 | c2.y = 0.25*c1.y + 0.75*c2.y | |
964 | end | |
965 | end | |
966 | end | |
967 | end | |
968 | ||
969 | gui.drawBox(50-BoxRadius*5-3,70-BoxRadius*5-3,50+BoxRadius*5+2,70+BoxRadius*5+2,0xFF000000, 0x80808080) | |
970 | for n,cell in pairs(cells) do | |
971 | if n > Inputs or cell.value ~= 0 then | |
972 | local color = math.floor((cell.value+1)/2*256) | |
973 | if color > 255 then color = 255 end | |
974 | if color < 0 then color = 0 end | |
975 | local opacity = 0xFF000000 | |
976 | if cell.value == 0 then | |
977 | opacity = 0x50000000 | |
978 | end | |
979 | color = opacity + color*0x10000 + color*0x100 + color | |
980 | gui.drawBox(cell.x-2,cell.y-2,cell.x+2,cell.y+2,opacity,color) | |
981 | end | |
982 | end | |
983 | for _,gene in pairs(genome.genes) do | |
984 | if gene.enabled then | |
985 | local c1 = cells[gene.into] | |
986 | local c2 = cells[gene.out] | |
987 | local opacity = 0xA0000000 | |
988 | if c1.value == 0 then | |
989 | opacity = 0x20000000 | |
990 | end | |
991 | ||
992 | local color = 0x80-math.floor(math.abs(sigmoid(gene.weight))*0x80) | |
993 | if gene.weight > 0 then | |
994 | color = opacity + 0x8000 + 0x10000*color | |
995 | else | |
996 | color = opacity + 0x800000 + 0x100*color | |
997 | end | |
998 | gui.drawLine(c1.x+1, c1.y, c2.x-3, c2.y, color) | |
999 | end | |
1000 | end | |
1001 | ||
1002 | gui.drawBox(49,71,51,78,0x00000000,0x80FF0000) | |
1003 | ||
1004 | if forms.ischecked(showMutationRates) then | |
1005 | local pos = 100 | |
1006 | for mutation,rate in pairs(genome.mutationRates) do | |
1007 | gui.drawText(100, pos, mutation .. ": " .. rate, 0xFF000000, 10) | |
1008 | pos = pos + 8 | |
1009 | end | |
1010 | end | |
1011 | end | |
1012 | ||
1013 | function writeFile(filename) | |
1014 | local file = io.open(filename, "w") | |
1015 | file:write(pool.generation .. "\n") | |
1016 | file:write(pool.maxFitness .. "\n") | |
1017 | file:write(#pool.species .. "\n") | |
1018 | for n,species in pairs(pool.species) do | |
1019 | file:write(species.topFitness .. "\n") | |
1020 | file:write(species.staleness .. "\n") | |
1021 | file:write(#species.genomes .. "\n") | |
1022 | for m,genome in pairs(species.genomes) do | |
1023 | file:write(genome.fitness .. "\n") | |
1024 | file:write(genome.maxneuron .. "\n") | |
1025 | for mutation,rate in pairs(genome.mutationRates) do | |
1026 | file:write(mutation .. "\n") | |
1027 | file:write(rate .. "\n") | |
1028 | end | |
1029 | file:write("done\n") | |
1030 | ||
1031 | file:write(#genome.genes .. "\n") | |
1032 | for l,gene in pairs(genome.genes) do | |
1033 | file:write(gene.into .. " ") | |
1034 | file:write(gene.out .. " ") | |
1035 | file:write(gene.weight .. " ") | |
1036 | file:write(gene.innovation .. " ") | |
1037 | if(gene.enabled) then | |
1038 | file:write("1\n") | |
1039 | else | |
1040 | file:write("0\n") | |
1041 | end | |
1042 | end | |
1043 | end | |
1044 | end | |
1045 | file:close() | |
1046 | end | |
1047 | ||
1048 | function savePool() | |
1049 | local filename = forms.gettext(saveLoadFile) | |
1050 | writeFile(filename) | |
1051 | end | |
1052 | ||
1053 | function loadFile(filename) | |
1054 | local file = io.open(filename, "r") | |
1055 | pool = newPool() | |
1056 | pool.generation = file:read("*number") | |
1057 | pool.maxFitness = file:read("*number") | |
1058 | forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) | |
1059 | local numSpecies = file:read("*number") | |
1060 | for s=1,numSpecies do | |
1061 | local species = newSpecies() | |
1062 | table.insert(pool.species, species) | |
1063 | species.topFitness = file:read("*number") | |
1064 | species.staleness = file:read("*number") | |
1065 | local numGenomes = file:read("*number") | |
1066 | for g=1,numGenomes do | |
1067 | local genome = newGenome() | |
1068 | table.insert(species.genomes, genome) | |
1069 | genome.fitness = file:read("*number") | |
1070 | genome.maxneuron = file:read("*number") | |
1071 | local line = file:read("*line") | |
1072 | while line ~= "done" do | |
1073 | genome.mutationRates[line] = file:read("*number") | |
1074 | line = file:read("*line") | |
1075 | end | |
1076 | local numGenes = file:read("*number") | |
1077 | for n=1,numGenes do | |
1078 | local gene = newGene() | |
1079 | table.insert(genome.genes, gene) | |
1080 | local enabled | |
1081 | gene.into, gene.out, gene.weight, gene.innovation, enabled = file:read("*number", "*number", "*number", "*number", "*number") | |
1082 | if enabled == 0 then | |
1083 | gene.enabled = false | |
1084 | else | |
1085 | gene.enabled = true | |
1086 | end | |
1087 | ||
1088 | end | |
1089 | end | |
1090 | end | |
1091 | file:close() | |
1092 | ||
1093 | while fitnessAlreadyMeasured() do | |
1094 | nextGenome() | |
1095 | end | |
1096 | initializeRun() | |
1097 | pool.currentFrame = pool.currentFrame + 1 | |
1098 | end | |
1099 | ||
1100 | function loadPool() | |
1101 | local filename = forms.gettext(saveLoadFile) | |
1102 | loadFile(filename) | |
1103 | end | |
1104 | ||
1105 | function playTop() | |
1106 | local maxfitness = 0 | |
1107 | local maxs, maxg | |
1108 | for s,species in pairs(pool.species) do | |
1109 | for g,genome in pairs(species.genomes) do | |
1110 | if genome.fitness > maxfitness then | |
1111 | maxfitness = genome.fitness | |
1112 | maxs = s | |
1113 | maxg = g | |
1114 | end | |
1115 | end | |
1116 | end | |
1117 | ||
1118 | pool.currentSpecies = maxs | |
1119 | pool.currentGenome = maxg | |
1120 | pool.maxFitness = maxfitness | |
1121 | forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) | |
1122 | initializeRun() | |
1123 | pool.currentFrame = pool.currentFrame + 1 | |
1124 | return | |
1125 | end | |
1126 | ||
1127 | function onExit() | |
1128 | forms.destroy(form) | |
1129 | end | |
1130 | ||
1131 | writeFile("temp.pool") | |
1132 | ||
1133 | event.onexit(onExit) | |
1134 | ||
1135 | form = forms.newform(200, 260, "Fitness") | |
1136 | maxFitnessLabel = forms.label(form, "Max Fitness: " .. math.floor(pool.maxFitness), 5, 8) | |
1137 | showNetwork = forms.checkbox(form, "Show Map", 5, 30) | |
1138 | showMutationRates = forms.checkbox(form, "Show M-Rates", 5, 52) | |
1139 | restartButton = forms.button(form, "Restart", initializePool, 5, 77) | |
1140 | saveButton = forms.button(form, "Save", savePool, 5, 102) | |
1141 | loadButton = forms.button(form, "Load", loadPool, 80, 102) | |
1142 | saveLoadFile = forms.textbox(form, Filename .. ".pool", 170, 25, nil, 5, 148) | |
1143 | saveLoadLabel = forms.label(form, "Save/Load:", 5, 129) | |
1144 | playTopButton = forms.button(form, "Play Top", playTop, 5, 170) | |
1145 | hideBanner = forms.checkbox(form, "Hide Banner", 5, 190) | |
1146 | ||
1147 | ||
1148 | while true do | |
1149 | local backgroundColor = 0xD0FFFFFF | |
1150 | if not forms.ischecked(hideBanner) then | |
1151 | gui.drawBox(0, 0, 300, 26, backgroundColor, backgroundColor) | |
1152 | end | |
1153 | ||
1154 | local species = pool.species[pool.currentSpecies] | |
1155 | local genome = species.genomes[pool.currentGenome] | |
1156 | ||
1157 | if forms.ischecked(showNetwork) then | |
1158 | displayGenome(genome) | |
1159 | end | |
1160 | ||
1161 | if pool.currentFrame%5 == 0 then | |
1162 | evaluateCurrent() | |
1163 | end | |
1164 | ||
1165 | joypad.set(controller) | |
1166 | ||
1167 | getPositions() | |
1168 | if marioX > rightmost then | |
1169 | rightmost = marioX | |
1170 | timeout = TimeoutConstant | |
1171 | end | |
1172 | ||
1173 | timeout = timeout - 1 | |
1174 | ||
1175 | ||
1176 | local timeoutBonus = pool.currentFrame / 4 | |
1177 | if timeout + timeoutBonus <= 0 then | |
1178 | local fitness = rightmost - pool.currentFrame / 2 | |
1179 | if gameinfo.getromname() == "Super Mario World (USA)" and rightmost > 4816 then | |
1180 | fitness = fitness + 1000 | |
1181 | end | |
1182 | if gameinfo.getromname() == "Super Mario Bros." and rightmost > 3186 then | |
1183 | fitness = fitness + 1000 | |
1184 | end | |
1185 | if fitness == 0 then | |
1186 | fitness = -1 | |
1187 | end | |
1188 | genome.fitness = fitness | |
1189 | ||
1190 | if fitness > pool.maxFitness then | |
1191 | pool.maxFitness = fitness | |
1192 | forms.settext(maxFitnessLabel, "Max Fitness: " .. math.floor(pool.maxFitness)) | |
1193 | writeFile("backup." .. pool.generation .. "." .. forms.gettext(saveLoadFile)) | |
1194 | end | |
1195 | ||
1196 | console.writeline("Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " fitness: " .. fitness) | |
1197 | pool.currentSpecies = 1 | |
1198 | pool.currentGenome = 1 | |
1199 | while fitnessAlreadyMeasured() do | |
1200 | nextGenome() | |
1201 | end | |
1202 | initializeRun() | |
1203 | end | |
1204 | ||
1205 | local measured = 0 | |
1206 | local total = 0 | |
1207 | for _,species in pairs(pool.species) do | |
1208 | for _,genome in pairs(species.genomes) do | |
1209 | total = total + 1 | |
1210 | if genome.fitness ~= 0 then | |
1211 | measured = measured + 1 | |
1212 | end | |
1213 | end | |
1214 | end | |
1215 | if not forms.ischecked(hideBanner) then | |
1216 | gui.drawText(0, 0, "Gen " .. pool.generation .. " species " .. pool.currentSpecies .. " genome " .. pool.currentGenome .. " (" .. math.floor(measured/total*100) .. "%)", 0xFF000000, 11) | |
1217 | gui.drawText(0, 12, "Fitness: " .. math.floor(rightmost - (pool.currentFrame) / 2 - (timeout + timeoutBonus)*2/3), 0xFF000000, 11) | |
1218 | gui.drawText(100, 12, "Max Fitness: " .. math.floor(pool.maxFitness), 0xFF000000, 11) | |
1219 | end | |
1220 | ||
1221 | pool.currentFrame = pool.currentFrame + 1 | |
1222 | ||
1223 | emu.frameadvance(); | |
1224 | end |