// An animated drawing of a Neural Network
// Based on Daniel Shiffman's example in "The Nature of Code", http://natureofcode.com
// Modified by Blake Porter
// www.blakeporterneuro.com
// Creates a network of neurons using a specified number of layers of neurons, inputs, and axons
// Synapse weights are created randomly and adjust with activity: active pre/post pairs show long-term potentiation (LTP), inactive synapses show long-term depression (LTD)
// Version 3 - synapses/axons are made by the user
// to do:
// - the buttons aren't really needed; make the interaction more intuitive
// - layers aren't needed either; just distinguish input neurons from the rest
// Note: this code works well online in a browser with processing.js; it does not always behave nicely when run natively in the Processing IDE.
int yoffset = 0; // browsers report mouse Y off by ~30 px for some reason; use 0 for the Processing IDE
// in a browser, also use mouseClicked instead of mousePressed
int bezz = 5; // corner radius for the rounded UI buttons
float APSpeed = 9; // pixels an action potential travels per animation step
// for input button
int rectX, rectY;
int rectSize = 140;
color rectColor;
color rectHighlight;
boolean rectOver = false;
int inputFont = 28;
boolean clickedGen = false;
// Layer display
int layerFont = 50;
// for layer buttons
int rect2X, rect2YU, rect2YD;
int rect2Size = 75;
color rect2Color;
color rect2UpHighlight;
color rect2DownHighlight;
boolean rectUpOver = false;
boolean rectDownOver = false;
int layer2Font = rect2Size;
// for neuron type
int selType = 1;
int nTypeRectX_E, nTypeRectY_E, nTypeRectX_I, nTypeRectY_I;
boolean rectOverE = false;
boolean rectOverI = false;
boolean rectPressedE = true;
boolean rectPressedI = false;
int EIFont = 50;
int rectEISizeW = 150;
int rectEISizeH = 80;
color rectEI_base = color(0, 0, 0);
color rectE_highlight = color(0, 255, 0);
color rectI_highlight = color(255, 0, 0);
boolean holdPair = true;
boolean onNeuron = false;
int neuronID = 0;
int axons = 1; // the number of axons a neuron has
int currLayer = 0;
float randWMax = 0.001; // upper bound for the random initial synaptic weights
String input;
int baseDelay = 30; // ms between automatic input pulses while Generate Input is on
int selectedNeuron;
int prevSelNeuron;
boolean logSelNeuron = false;
boolean overNothing = true;
//boolean skipN = false;
boolean wasGen = false;
boolean isGen = false;
// initialize
int startRand = 0;
int endRand = 0;
int lastMsec = 0;
int currMsec = 0;
int numLayers = 0;
Network network;
void setup() {
size(1200, 600);
rect2X = 10;
rect2YU = 10;
rect2YD = rect2YU+rect2Size+15;
rect2Color = color(0);
rect2UpHighlight = color(128);
rect2DownHighlight = color(128);
rectColor = color(0);
rectHighlight = color(5, 178, 255);
rectX = width - rectEISizeW - 15 - rectSize;
rectY = rect2YU;
nTypeRectX_E = width - rectEISizeW - 10;
nTypeRectX_I = width - rectEISizeW - 10;
nTypeRectY_E = rect2YU;
nTypeRectY_I = rect2YU+rect2Size+15;
// Create the Network object
network = new Network(width/2, height/2);
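// A minimal illustration (commented out, not part of the original sketch) of
// wiring two neurons in code the same way two clicks would: the first neuron
// is presynaptic, the second postsynaptic, and the receiving neuron stores the
// randomly initialised weight.
// Neuron pre  = new Neuron(width*0.3, height/2, neuronID++, 0, 1); // layer 0, excitatory
// Neuron post = new Neuron(width*0.7, height/2, neuronID++, 1, 1); // layer 1, excitatory
// network.addNeuron(pre);
// network.addNeuron(post);
// post.addConnection(pre.neuronNum, random(0, randWMax)); // postsynaptic side holds the weight
// pre.addPostSyn(post.neuronNum); // presynaptic side records its target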
}
void draw() {
background(255);
pushStyle();
fill(rect2Color);
rect(rect2X, rect2YU, rect2Size, rect2Size, bezz);
rect(rect2X, rect2YD, rect2Size, rect2Size, bezz);
fill(255);
String arrowUp = "↑";
String arrowDown = "↓";
textSize(layer2Font);
text(arrowUp, rect2X+(rect2Size/4), rect2Size-12);
text(arrowDown, rect2X+(rect2Size/4), rect2YD+rect2Size-20);
popStyle();
pushStyle();
stroke(0);
rect(rectX, rectY, rectSize, rect2Size*2+20, bezz);
String APprompt = "Generate Input";
fill(rectColor);
textSize(inputFont);
textAlign(CENTER);
stroke(0);
fill(255);
text(APprompt, rectX+4, rectY+40, rectSize, rectSize);
popStyle();
// Current layer info
pushStyle();
textSize(layerFont*1.5);
String layerText;
if (currLayer == 0) {
layerText = "Input";
} else {
layerText = "Hidden";
}
fill(0, 0, 0);
text(layerText, rect2Size+rect2X+10, layerFont*2);
popStyle();
// Neuron type
pushStyle();
textSize(EIFont);
if (rectPressedE) {
fill(rectE_highlight);
rect(nTypeRectX_E, nTypeRectY_E, rectEISizeW, rectEISizeH, bezz);
fill(0);
text("Excite", nTypeRectX_E+2, rect2Size-10);
} else {
fill(rectEI_base);
rect(nTypeRectX_E, nTypeRectY_E, rectEISizeW, rectEISizeH, bezz);
fill(255);
text("Excite", nTypeRectX_E+2, rect2Size-10);
}
if (rectPressedI) {
fill(rectI_highlight);
rect(nTypeRectX_I, nTypeRectY_I, rectEISizeW, rectEISizeH, bezz);
fill(255);
text("Inhib", nTypeRectX_I+15, nTypeRectY_I+rect2Size-15);
} else {
fill(rectEI_base);
rect(nTypeRectX_I, nTypeRectY_I, rectEISizeW, rectEISizeH, bezz);
fill(255);
text("Inhib", nTypeRectX_I+15, nTypeRectY_I+rect2Size-15);
}
popStyle();
if (keyPressed) {
int keyVal = key - '0'; // number keys 0-9 select the current layer directly
if (keyVal >= 0 && keyVal <= 9) {
currLayer = keyVal;
}
if (currLayer < 0) {
currLayer = 0;
}
}
// Note: the hover flags (rectOver, rectUpOver, etc.) are only refreshed inside
// mousePressed() via update(), not every frame; this keeps the buttons from
// flashing as the mouse moves over them.
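// Rough numbers: every baseDelay (30) ms each layer-0 neuron receives 50 units of
// current, so with a spiking threshold of spkT = 1000 it takes about 20 pulses
// (~0.6 s, ignoring leak) before the input neurons fire.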
if (isGen) {
currMsec = millis();
fill(rectHighlight);
if (currMsec > lastMsec + baseDelay) {
for (int i = 0; i < network.neurons.size(); i++) {
Neuron currN = network.neurons.get(i);
if (currN.layer == 0) {
currN.feedForward(50);
lastMsec = millis();
}
}
}
} else {
fill(rectColor);
}
// Update and display the Network
network.update();
network.display();
} // end draw
void mousePressed () { // change to mouseClicked for browsers, mousePressed for Processing IDE
//float x = mouseX + 0 - width/2;
//float y = mouseY -yoffset - height/2;
float x = mouseX;
float y = mouseY - yoffset;
update(mouseX, mouseY-yoffset);
// println("__________");
// println("rectUpOver");
// println(rectUpOver);
// println("rectDownOver");
// println(rectDownOver);
// println("rectOver");
// println(rectOver);
// println("logSelNeuron");
// println(logSelNeuron);
// println("overNothing");
// println(overNothing);
// println("skipN");
// println(skipN);
//println("isGen");
//println(isGen);
if (!overNothing) {
if (rectUpOver) {
currLayer = currLayer+1;
currLayer = constrain(currLayer, 0, 1);
}
if (rectDownOver) {
currLayer = currLayer -1;
currLayer = constrain(currLayer, 0, 1);
}
if (rectOverE) {
rectPressedE = true;
rectPressedI = false;
selType = 1;
}
if (rectOverI) {
rectPressedI = true;
rectPressedE = false;
selType = 0;
}
//if (rectOverN) {
// recPressedN = true;
// recPressedS = false;
//}
//if (rectOverS) {
// recPressedS = true;
// recPressedN = false;
//}
} else if (!onNeuron) {
Neuron n = new Neuron(x, y, neuronID, currLayer, selType);
neuronID = neuronID + 1;
network.addNeuron(n);
prevSelNeuron = 0;
selectedNeuron = 0;
holdPair = true;
} else if (onNeuron && selectedNeuron != prevSelNeuron) {
Neuron nPre = network.neurons.get(selectedNeuron);
Neuron nPost = network.neurons.get(prevSelNeuron);
int newConnection = nPost.neuronNum;
nPre.addConnection(newConnection, random(0, randWMax));
nPost.addPostSyn(selectedNeuron);
selectedNeuron = 0;
prevSelNeuron = 0;
holdPair = true;
}
}
void update(int x, int y) {
overNothing = true;
onNeuron = false;
rectOver = false;
overNeuron();
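// The Generate Input button toggles continuous input on and off: the first
// click sets isGen, a second click clears it again.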
if ( overRect(rectX, rectY, rectSize, rect2Size*2+20) && !clickedGen ) {
overNothing = false;
isGen = true;
clickedGen = true;
rectOver = true;
} else if ( overRect(rectX, rectY, rectSize, rect2Size*2+20) && clickedGen ) {
overNothing = false;
isGen = false;
wasGen = true;
clickedGen = false;
rectOver = true;
}
if (overRect(rect2X, rect2YU, rect2Size, rect2Size)) {
rectUpOver = true;
overNothing = false;
} else {
rectUpOver = false;
}
if (overRect(rect2X, rect2YD, rect2Size, rect2Size)) {
rectDownOver = true;
overNothing = false;
} else {
rectDownOver = false;
}
if (overRect(nTypeRectX_E, nTypeRectY_E, rectEISizeW, rectEISizeH)) {
rectOverE = true;
overNothing = false;
} else {
rectOverE = false;
}
if (overRect(nTypeRectX_I, nTypeRectY_I, rectEISizeW, rectEISizeH)) {
rectOverI = true;
overNothing = false;
} else {
rectOverI = false;
}
}
boolean overRect(int x, int y, int w, int h) {
// true when the mouse (corrected by yoffset) is inside the given rectangle
return mouseX >= x && mouseX <= x+w &&
mouseY-yoffset >= y && mouseY-yoffset <= y+h;
}
void overNeuron() {
for (int i = 0; i <= network.neurons.size()-1; i++) {
Neuron curr = network.neurons.get(i);
if (mouseX >= curr.location.x-curr.r_base && mouseX <= curr.location.x+curr.r_base
&& mouseY-yoffset >= curr.location.y-curr.r_base && mouseY-yoffset <= curr.location.y+curr.r_base) {
onNeuron = true;
if (holdPair) {
prevSelNeuron = curr.neuronNum;
selectedNeuron = curr.neuronNum;
holdPair = false;
} else {
prevSelNeuron = selectedNeuron;
selectedNeuron = curr.neuronNum;
}
}
}
}
class ActionPotentials {
PVector location; // current position of the travelling action potential
PVector receiver; // location of the postsynaptic (receiving) neuron
PVector sender; // location of the neuron that fired
int recNum; // index of the receiving neuron in network.neurons
float weight; // synaptic weight carried by this AP
float lerpCounter; // counter from the old lerp-based animation (not used by send())
float m; // slope of the straight-line path from sender to receiver
float b; // y-intercept of that line (not used by send())
float distance; // straight-line distance from sender to receiver
float r; // sqrt(1 + m^2), used to step a fixed distance along the line
float d; // distance travelled so far
ActionPotentials(PVector loc, PVector rec, int recNumber, float w, float LC) {
location = new PVector(loc.x, loc.y);
receiver = new PVector(rec.x, rec.y);
sender = new PVector(loc.x, loc.y);
distance = dist(sender.x, sender.y, rec.x, rec.y);
float dx = rec.x - sender.x;
if (abs(dx) < 0.0001) {
dx = 0.0001; // avoid division by zero (NaN position) when sender and receiver share an x coordinate
}
m = (rec.y-sender.y)/dx;
b = rec.y-(m*rec.x);
r = sqrt(1+pow(m, 2));
recNum = recNumber;
weight = w;
lerpCounter = LC;
}
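// send() advances the AP a distance d along the straight line from sender to
// receiver (see the Math.SE link below): for slope m and r = sqrt(1 + m^2),
// moving d pixels along the line changes x by d/r and y by d*m/r.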
void send() {
// https://math.stackexchange.com/questions/656500/given-a-point-slope-and-a-distance-along-that-slope-easily-find-a-second-p
if (sender.x <= receiver.x) {
location.x = sender.x+(d/r);
location.y = sender.y+((d*m)/r);
d = APSpeed + d;
} else {
location.x = sender.x-(d/r);
location.y = sender.y-((d*m)/r);
d = APSpeed + d;
}
}
}
int baseCurrent = 150; // current injected into input neurons by baseInput()
float normInput = 200; // synaptic current delivered per arriving AP (scaled by the weight)
int leak = 6; // passive leak subtracted from each membrane potential every update
float LTP = 1.1; // weight increase for co-active pre/post pairs (long-term potentiation)
float LTD = 0.25; // weight decrease for inactive synapses (long-term depression)
int maxW = 5; // maximum synaptic weight
int updateTime = 500; // ms between leak/plasticity updates of the network
int prevUpdate = 0;
int currUpdate = 0;
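// Rough numbers, for orientation: with LTP = 1.1 a weight climbs from ~0 to the
// maxW = 5 ceiling after about five coincident pre/post firings, and with
// LTD = 0.25 it decays back to 0 after roughly twenty silent updates (about 10 s).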
class Network {
// The Network has a list of neurons
ArrayList<Neuron> neurons;
// Connections are stored on each Neuron; the Network just keeps the list of
// neurons and iterates over it to update and draw everything
PVector location;
Network(float x, float y) {
location = new PVector(x, y);
neurons = new ArrayList<Neuron>();
}
// We can add a Neuron
void addNeuron(Neuron n) {
neurons.add(n);
}
// Sending an input to the first layer of neurons
void baseInput(int inputs) {
for (int i = 0; i < inputs; i++) {
Neuron n1 = neurons.get(i);
n1.feedForward(baseCurrent);
}
}
// Update connections
void update() {
if (network.neurons.size() > 1) {
currUpdate = millis();
if (currUpdate > prevUpdate + updateTime) {
for (Neuron n : neurons) {
n.checkFire();
n.MP = n.MP - leak;
n.MP = constrain(n.MP, n.minMP, n.spkT);
ArrayList currAPs = n.APs;
int countAPs = n.APs.size();
if (countAPs > 0) {
for (int i = countAPs-1; i >= 0; i--) {
ActionPotentials currAP = n.APs.get(i);
if (abs(currAP.location.x-currAP.receiver.x) < n.APr && abs(currAP.location.y-currAP.receiver.y) < n.APr) {
Neuron recN = network.neurons.get(currAP.recNum);
float currW = currAP.weight;
currW = constrain(currW, 1, maxW);
if (n.type == 1) {
recN.feedForward(normInput*currW);
} else if (n.type == 0) {
recN.feedForward(normInput*currW*-1);
}
n.APs.remove(i);
} else {
currAP.send();
//currAP.location.x = lerp(currAP.sender.x, currAP.receiver.x, (currAP.lerpCounter/lerpVal));
//currAP.location.y = lerp(currAP.sender.y, currAP.receiver.y, (currAP.lerpCounter/lerpVal));
//currAP.lerpCounter ++;
//currAP.lerpCounter = constrain(currAP.lerpCounter, 0, lerpVal);
//currAP.location.x = lerp(currAP.location.x, currAP.receiver.x, 0.1);
//currAP.location.y = lerp(currAP.location.y, currAP.receiver.y, 0.1);
}
}
}
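// Hebbian plasticity: if this neuron and its presynaptic partner both fired
// within the last 500 ms, the synapse is strengthened by LTP; otherwise it is
// weakened by LTD. Weights are clamped to [0, maxW].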
for (int i = 0; i < n.connections.length; i++) {
Neuron preN = network.neurons.get(n.connections[i]);
if (preN.layer == 0) {
if (n.justFired && preN.justFired && (millis() - preN.lastAP) < 500 && (millis() - n.lastAP) < 500) { // n.spkCount > 0 && preN.spkCount > 0 &&
n.weights[i] = n.weights[i] + LTP;
//n.weights.add(i, LTP);
float currW = n.weights[i];
currW = constrain(currW, 0, maxW);
n.weights[i] = currW;
} else {
n.weights[i] = n.weights[i] - LTD;
//n.weights.sub(i, LTD);
float currW = n.weights[i];
currW = constrain(currW, 0, maxW);
n.weights[i] = currW;
} // spike count
} else {
if (n.justFired && preN.justFired && (millis() - preN.lastAP) < 500 && (millis() - n.lastAP) < 500) { //n.spkCount >= preN.spkCount && n.spkCount > 0 && preN.spkCount > 0 &&
n.weights[i] = n.weights[i] + LTP;
float currW = n.weights[i];
currW = constrain(currW, 0, maxW);
n.weights[i] = currW;
} else {
n.weights[i] = n.weights[i] - LTD;
//n.weights.sub(i, LTD);
float currW = n.weights[i];
currW = constrain(currW, 0, maxW);
n.weights[i] = currW;
} // spike count
}
prevUpdate = millis();
} // for all connections
n.spkCount = 0;
} // all neurons
} else {
for (Neuron n : neurons) {
n.MP = n.MP - leak*0.2;
n.MP = constrain(n.MP, n.minMP, n.spkT);
ArrayList currAPs = n.APs;
int countAPs = n.APs.size();
if (countAPs > 0) {
for (int i = countAPs-1; i >= 0; i--) {
ActionPotentials currAP = n.APs.get(i);
if (abs(currAP.location.x-currAP.receiver.x) < n.APr && abs(currAP.location.y-currAP.receiver.y) < n.APr) {
Neuron recN = network.neurons.get(currAP.recNum);
float currW = currAP.weight;
currW = constrain(currW, 1, maxW);
if (n.type == 1) {
recN.feedForward(normInput*currW);
} else if (n.type == 0) {
recN.feedForward(normInput*currW*-1);
}
n.APs.remove(i);
} else {
currAP.send();
//currAP.location.x = lerp(currAP.sender.x, currAP.receiver.x, (currAP.lerpCounter/lerpVal));
//currAP.location.y = lerp(currAP.sender.y, currAP.receiver.y, (currAP.lerpCounter/lerpVal));
//currAP.lerpCounter ++;
//currAP.lerpCounter = constrain(currAP.lerpCounter, 0, lerpVal);
// currAP.location.x = lerp(currAP.location.x, currAP.receiver.x, 0.1);
// currAP.location.y = lerp(currAP.location.y, currAP.receiver.y, 0.1);
}
}
}
}
} // time
}// size
}// update
// Draw everything
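// Draw order matters: axons first, then neuron bodies, then travelling APs,
// so the action potentials are rendered on top of everything else.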
void display() {
//pushMatrix();
// translate(location.x, location.y);
for (Neuron n : neurons) {
n.displayAx();
}
for (Neuron n : neurons) {
n.displayN();
}
for (Neuron n : neurons) {
n.displayAP();
}
// popMatrix();
}
}
class Neuron {
int neuronNum;
PVector location; // Neuron has a start location, x and y
PVector centerLoc; // Neuron has a central location, x and y, from which to draw its axon
int prevSpkCount = 0; // spike count at the last check (currently unused)
int spkCount = 0; // how many times it has fired since we last checked
int layer;
boolean justFired = false;
int type;
int refract = 150; // refractory period in ms
int lastAP; // time (ms) of the most recent action potential
int RMP = 0; // resting membrane potential
int minMP = -500; // lower bound on the membrane potential (maximum hyperpolarization)
int spkT = 1000; // spiking threshold
float MP = RMP; // the intracellular voltage; starts at the resting value
// Neuron has a list of connections
int[] connections = {}; // IDs of presynaptic neurons feeding into this one
float[] weights = {}; // synaptic weight for each entry in connections
int[] postSyn = {}; // IDs of postsynaptic neurons this one sends APs to
ArrayList<ActionPotentials> APs = new ArrayList<ActionPotentials>();
// The Neuron's size can be animated
float r_base = 20; // how big it is at the start
float r_pop = 30; // how big it grows to when it fires
float r = r_base; // for reference
float APr = r_base*0.5; // radius used to draw APs and to detect their arrival
// Center of triangle of the neuron
float centerX;
float centerY;
Neuron(float x, float y, int neuronID, int currLayer, int selType) {
location = new PVector(x, y); // get start position
centerLoc = new PVector(location.x, location.y); //calculate center
//centerLoc = new PVector(((location.x + (location.x+r) + (location.x+(r/2)))/3),(location.y + location.y + (location.y-r))/3); //calculate center
layer = currLayer;
neuronNum = neuronID;
type = selType;
}
// Add a Connection
void addConnection(int c, float w) {
connections = append(connections, c);
weights = append(weights, w);
//connections.append(c);
//weights.append(w);
}
void addPostSyn(int x) {
postSyn = append(postSyn, x);
//postSyn.append(x);
}
// Receive an input
void feedForward(float input) {
// Accumulate it
MP += input;
MP = constrain(MP, minMP, spkT);
// did the membrane potential reach the spiking threshold (spkT), outside the refractory period?
if (MP >= spkT && (millis() - lastAP) > refract) {
fire();
justFired = true;
MP = RMP; // Reset the MP to 0 if it fires
spkCount++; // add a spike to the spike count
float apW = 0.0;
lastAP = millis();
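// Launch one travelling AP toward every postsynaptic target; each AP carries
// the weight stored on the receiving neuron, which lists this neuron among its
// connections/weights arrays.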
for (int i = 0; i <= postSyn.length-1; i++) {
int currConnect = postSyn[i];
Neuron recN = network.neurons.get(currConnect);
PVector rec = recN.location;
for (int j = 0; j < recN.connections.length; j++) {
int isThisConnected = recN.connections[j];
if (isThisConnected == neuronNum) {
apW = recN.weights[j];
}
}
APs.add(new ActionPotentials(location, rec, recN.neuronNum, apW, 1.0));
}
} else if (MP >= spkT && (millis() - lastAP) < refract) {
MP = spkT-250;
}
}
// Clear the justFired flag once 2 s have passed since the last spike
void checkFire() {
if ((millis() - lastAP) > 2000 && justFired) {
justFired = false;
}
}
// The Neuron fires
void fire() {
r = r_pop; // It suddenly is bigger
// We send the output through all connections
}
void displayN() {
// neurons
pushStyle();
stroke(0);
strokeWeight(1);
if (type == 1) {
color fromC = color(255, 255, 255);
color startC = color(rectE_highlight); //color(74, 209, 111);
color endC = color(0, 128, 0);
float layerF = layer;
color toC = lerpColor(startC, endC, layerF/9);
float memP = map(MP, minMP, spkT, 0, 1);
color neuronC = lerpColor(fromC, toC, memP);
fill(neuronC);
}
if (type == 0) {
color fromC = color(255, 255, 255);
color startC = color(rectI_highlight); //color(74, 209, 111);
color endC = color(128, 0, 0);
float layerF = layer;
color toC = lerpColor(startC, endC, layerF/9);
float memP = map(MP, minMP, spkT, 0, 1);
color neuronC = lerpColor(fromC, toC, memP);
fill(neuronC);
}
ellipse(location.x, location.y, r, r);
r = lerp(r, r_base, 0.1);
popStyle();
} // display N
void displayAx() {
// Axons
pushStyle();
stroke(0, 0, 0);
for (int i = 0; i < connections.length; i++) {
int currPost = connections[i];
Neuron post = network.neurons.get(currPost);
float currW = weights[i];
strokeWeight(currW);
line(centerLoc.x, centerLoc.y, post.centerLoc.x, post.centerLoc.y);
}
popStyle();
} // display axon
void displayAP() {
// AP
pushStyle();
for (int i = 0; i < APs.size(); i++) {
ActionPotentials currAP = APs.get(i);
//color fromAPC = color(242, 223, 0);
//color toAPC = color(11, 107, 191);
//float layerAPF = layer;
//color apC = lerpColor(fromAPC, toAPC, layerAPF/9);
stroke(100);
strokeWeight(1);
color apC = color(255, 255, 0);
fill(apC);
ellipse(currAP.location.x, currAP.location.y, APr, APr); // draw APs at half the neuron's base size
}
popStyle();
} // display ap
}