今天就跟大家聊聊有关dl4j如何使用遗传神经网络完成手写数字识别,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。
实现步骤
1.随机初始化若干个智能体(神经网络),并让智能体识别训练数据,并对识别结果进行排序
2.随机在排序结果中选择一个作为母本,并在比母本识别率更高的智能体中随机选择一个作为父本
3.随机选择母本或父本同位的神经网络超参组成新的智能体
4.按照母本的排序对新智能体进行超参调整,调整幅度在1%~10%之间,母本排序越靠后调整幅度越大
5.让新的智能体识别训练集并放入排行榜,并移除排行榜最后一位
6.重复2~5过程,让识别率越来越高
这个过程类似于自然界的优胜劣汰:将神经网络超参看作dna,超参的调整看作dna的突变。还可以把拥有不同隐藏层的神经网络看作不同的物种,让竞争过程更加多样化;不过我们这里只讨论一种神经网络的情况。
优势: 可以解决很多没有头绪的问题 劣势: 训练效率极低
gitee地址:
https://gitee.com/ichiva/gnn.git
实现步骤 1.进化接口
/**
 * Genetic operators that act on an agent's "DNA", represented as an {@link INDArray}.
 *
 * <p>Every operator takes one or two DNA arrays and returns a DNA array. The meaning of
 * the numeric tuning arguments ({@code v}, {@code r}) is only named here; how they are
 * applied is up to the implementation.
 */
public interface Evolution {

    /**
     * Inheritance: derives a DNA from a mother DNA and a father DNA.
     *
     * @param mDna DNA of the mother
     * @param fDna DNA of the father
     * @return the inherited DNA
     */
    INDArray inheritance(INDArray mDna, INDArray fDna);

    /**
     * Mutation.
     *
     * @param dna DNA to mutate
     * @param v   tuning value (presumably a per-gene rate; confirm against the implementation)
     * @param r   mutation range
     * @return the resulting DNA
     */
    INDArray mutation(INDArray dna, double v, double r);

    /**
     * Substitution (transposition of genes).
     *
     * @param dna DNA to operate on
     * @param v   tuning value (presumably a per-gene rate; confirm against the implementation)
     * @return the resulting DNA
     */
    INDArray substitution(INDArray dna, double v);

    /**
     * Exogenous genes: introduces material that does not come from {@code dna}.
     *
     * @param dna DNA to operate on
     * @param v   tuning value (presumably a per-gene rate; confirm against the implementation)
     * @return the resulting DNA
     */
    INDArray other(INDArray dna, double v);

    /**
     * Tells whether two DNAs are homologous, i.e. may be combined with each other.
     *
     * @param mDna DNA of the mother
     * @param fDna DNA of the father
     * @return {@code true} if the two DNAs are homologous
     */
    boolean iSogeny(INDArray mDna, INDArray fDna);
}
一个比较通用的实现
public class MnistEvolution implements Evolution { private static final MnistEvolution instance = new MnistEvolution(); public static MnistEvolution getInstance() { return instance; } @Override public INDArray inheritance(INDArray mDna, INDArray fDna) { if(mDna == fDna) return mDna; long[] mShape = mDna.shape(); if(!iSogeny(mDna,fDna)){ throw new RuntimeException("非同源dna"); } INDArray nDna = Nd4j.create(mShape); NdIndexIterator it = new NdIndexIterator(mShape); while (it.hasNext()){ long[] next = it.next(); double val; if(Math.random() > 0.5){ val = fDna.getDouble(next); }else { val = mDna.getDouble(next); } nDna.putScalar(next,val); } return nDna; } @Override public INDArray mutation(INDArray dna, double v, double r) { long[] shape = dna.shape(); INDArray nDna = Nd4j.create(shape); NdIndexIterator it = new NdIndexIterator(shape); while (it.hasNext()) { long[] next = it.next(); if(Math.random() < v){ dna.putScalar(next,dna.getDouble(next) + ((Math.random() - 0.5) * r * 2)); }else { nDna.putScalar(next,dna.getDouble(next)); } } return nDna; } @Override public INDArray substitution(INDArray dna, double v) { long[] shape = dna.shape(); INDArray nDna = Nd4j.create(shape); NdIndexIterator it = new NdIndexIterator(shape); while (it.hasNext()) { long[] next = it.next(); if(Math.random() > v){ long[] tag = new long[shape.length]; for (int i = 0; i < shape.length; i++) { tag[i] = (long) (Math.random() * shape[i]); } nDna.putScalar(next,dna.getDouble(tag)); }else { nDna.putScalar(next,dna.getDouble(next)); } } return nDna; } @Override public INDArray other(INDArray dna, double v) { long[] shape = dna.shape(); INDArray nDna = Nd4j.create(shape); NdIndexIterator it = new NdIndexIterator(shape); while (it.hasNext()) { long[] next = it.next(); if(Math.random() > v){ nDna.putScalar(next,Math.random()); }else { nDna.putScalar(next,dna.getDouble(next)); } } return nDna; } @Override public boolean iSogeny(INDArray mDna, INDArray fDna) { long[] mShape = mDna.shape(); long[] fShape = 
fDna.shape(); if (mShape.length == fShape.length) { for (int i = 0; i < mShape.length; i++) { if (mShape[i] != fShape[i]) { return false; } } return true; } return false; } }
定义智能体配置接口
/**
 * Configuration of an agent: the size of its input and output and the neural network
 * configuration it is built from.
 */
public interface AgentConfig {

    /**
     * @return the input size
     */
    int getInput();

    /**
     * @return the output size
     */
    int getOutput();

    /**
     * @return the neural network configuration
     */
    MultiLayerConfiguration getMultiLayerConfiguration();
}
按手写数字识别进行配置实现
public class MnistConfig implements AgentConfig { @Override public int getInput() { return 28 * 28; } @Override public int getOutput() { return 10; } @Override public MultiLayerConfiguration getMultiLayerConfiguration() { return new NeuralNetConfiguration.Builder() .seed((long) (Math.random() * Long.MAX_VALUE)) .updater(new Nesterovs(0.006, 0.9)) .l2(1e-4) .list() .layer(0, new DenseLayer.Builder() .nIn(getInput()) .nOut(1000) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer .nIn(1000) .nOut(getOutput()) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) .build(); } }
智能体基类
/**
 * Base class of an agent: a neural network together with the configuration it was built
 * from and its DNA.
 */
@Getter
public class Agent {

    private final AgentConfig config;
    private final INDArray dna;
    private final MultiLayerNetwork multiLayerNetwork;

    /**
     * Creates an agent whose network parameters are initialised the default way.
     *
     * @param config agent configuration
     */
    public Agent(AgentConfig config) {
        this(config, null);
    }

    /**
     * Creates an agent from a configuration and, optionally, an existing DNA.
     *
     * @param config agent configuration
     * @param dna    parameters to initialise the network with, or {@code null} to let the
     *               network initialise itself; in that case the network's own parameters
     *               become this agent's DNA
     */
    public Agent(AgentConfig config, INDArray dna) {
        this.config = config;

        MultiLayerNetwork network = new MultiLayerNetwork(config.getMultiLayerConfiguration());
        if (dna != null) {
            // NOTE(review): the second argument is presumably DL4J's "clone parameters" flag - confirm.
            network.init(dna, true);
            this.dna = dna;
        } else {
            network.init();
            this.dna = network.params();
        }
        this.multiLayerNetwork = network;
    }
}
手写数字智能体实现类
/**
 * Agent for handwritten digit recognition. Each instance gets a sequential name
 * ({@code agent-1}, {@code agent-2}, ...) and carries two scores.
 */
@Getter
@Setter
public class MnistAgent extends Agent {

    /** Source of the numeric suffix used in agent names. */
    private static final AtomicInteger sequence = new AtomicInteger(0);

    /** Configuration used by the {@link #newInstance()} and {@link #create(INDArray)} factories. */
    public static MnistConfig mnistConfig = new MnistConfig();

    private String name;

    /** Environment fitness score. */
    private double score;

    /** Validation score. */
    private double validScore;

    /**
     * Creates an agent with default-initialised network parameters.
     *
     * @param config agent configuration
     */
    public MnistAgent(AgentConfig config) {
        this(config, null);
    }

    /**
     * Creates an agent from the given DNA and assigns it the next sequential name.
     *
     * @param config agent configuration
     * @param dna    DNA handed to the base class; may be {@code null}
     */
    public MnistAgent(AgentConfig config, INDArray dna) {
        super(config, dna);
        this.name = "agent-" + sequence.incrementAndGet();
    }

    /**
     * @return a new agent built from {@link #mnistConfig} with default-initialised parameters
     */
    public static MnistAgent newInstance() {
        return new MnistAgent(mnistConfig);
    }

    /**
     * @param dna DNA of the new agent
     * @return a new agent built from {@link #mnistConfig} and the given DNA
     */
    public static MnistAgent create(INDArray dna) {
        return new MnistAgent(mnistConfig, dna);
    }
}
手写数字识别环境构建
@Slf4j public class MnistEnv { /** * 环境数据 */ private static final ThreadLocal<MnistDataSetIterator> tLocal = ThreadLocal.withInitial(() -> { try { return new MnistDataSetIterator(128, true, 0); } catch (IOException e) { throw new RuntimeException("mnist 文件读取失败"); } }); private static final ThreadLocal<MnistDataSetIterator> testLocal = ThreadLocal.withInitial(() -> { try { return new MnistDataSetIterator(128, false, 0); } catch (IOException e) { throw new RuntimeException("mnist 文件读取失败"); } }); private static final MnistEvolution evolution = MnistEvolution.getInstance(); /** * 环境承载上限 * * 超过上限AI会进行激烈竞争 */ private final int max; private Double maxScore,minScore; /** * 环境中的生命体 * * 新生代与历史代共同排序,选出最适应环境的个体 */ //2个变量,一个队列保存KEY的顺序,一个MAP保存KEY对应的具体对象的数据 线程安全map private final TreeMap<Double,MnistAgent> lives = new TreeMap<>(); /** * 初始化环境 * * 1.向环境中初始化ai * 2.将初始化ai进行环境适应性测试,并排序 * @param max */ public MnistEnv(int max){ this.max = max; for (int i = 0; i < max; i++) { MnistAgent agent = MnistAgent.newInstance(); test(agent); synchronized (lives) { lives.put(agent.getScore(),agent); } log.info("初始化智能体 name = {} , score = {}",i,agent.getScore()); } synchronized (lives) { minScore = lives.firstKey(); maxScore = lives.lastKey(); } } /** * 环境适应性评估 * @param ai */ public void test(MnistAgent ai){ MultiLayerNetwork network = ai.getMultiLayerNetwork(); MnistDataSetIterator dataIterator = tLocal.get(); Evaluation eval = new Evaluation(ai.getConfig().getOutput()); try { while (dataIterator.hasNext()) { DataSet data = dataIterator.next(); INDArray output = network.output(data.getFeatures(), false); eval.eval(data.getLabels(),output); } }finally { dataIterator.reset(); } ai.setScore(eval.accuracy()); } /** * 迁移评估 * * @param ai */ public void validation(MnistAgent ai){ MultiLayerNetwork network = ai.getMultiLayerNetwork(); MnistDataSetIterator dataIterator = testLocal.get(); Evaluation eval = new Evaluation(ai.getConfig().getOutput()); try { while (dataIterator.hasNext()) { DataSet data = 
dataIterator.next(); INDArray output = network.output(data.getFeatures(), false); eval.eval(data.getLabels(),output); } }finally { dataIterator.reset(); } ai.setValidScore(eval.accuracy()); } /** * 进化 * * 每轮随机创建ai并放入环境中进行优胜劣汰 * @param n 进化次数 */ public void evolution(int n){ BlockThreadPool blockThreadPool=new BlockThreadPool(2); for (int i = 0; i < n; i++) { blockThreadPool.execute(() -> contend(newLive())); } // for (int i = 0; i < n; i++) { // contend(newLive()); // } } /** * 竞争 * @param ai */ public void contend(MnistAgent ai){ test(ai); quality(ai); double score = ai.getScore(); if(score <= minScore){ UI.put("无法生存",String.format("name = %s, score = %s", ai.getName(),ai.getScore())); return; } Map.Entry<Double, MnistAgent> lastEntry; synchronized (lives) { lives.put(score,ai); if (lives.size() > max) { MnistAgent lastAI = lives.remove(lives.firstKey()); UI.put("淘 汰 ",String.format("name = %s, score = %s", lastAI.getName(),lastAI.getScore())); } lastEntry = lives.lastEntry(); minScore = lives.firstKey(); } Double lastScore = lastEntry.getKey(); if(lastScore > maxScore){ maxScore = lastScore; MnistAgent agent = lastEntry.getValue(); validation(agent); UI.put("max验证",String.format("score = %s,validScore = %s",lastScore,agent.getValidScore())); try { Warehouse.write(agent); } catch (IOException ex) { log.error("保存对象失败",ex); } } } ArrayList<Double> scoreList = new ArrayList<>(100); ArrayList<Integer> avgList = new ArrayList<>(); private void quality(MnistAgent ai) { synchronized (scoreList) { scoreList.add(ai.getScore()); if (scoreList.size() >= 100) { double avg = scoreList.stream().mapToDouble(e -> e) .average().getAsDouble(); avgList.add((int) (avg * 1000)); StringBuffer buffer = new StringBuffer(); avgList.forEach(e -> buffer.append(e).append('\t')); UI.put("平均得分",String.format("aix100 avg = %s",buffer.toString())); scoreList.clear(); } } } /** * 随机生成新智能体 * * 完全随机产生母本 * 随机从比目标相同或更高评分中选择父本 * * 基因进化在1%~10%之间进行,评分越高基于越稳定 */ public MnistAgent newLive(){ double r = 
Math.random(); //基因突变率 double v = r / 11 + 0.01; //母本 MnistAgent mAgent = getMother(r); //父本 MnistAgent fAgent = getFather(r); int i = (int) (Math.random() * 3); INDArray newDNA = evolution.inheritance(mAgent.getDna(), fAgent.getDna()); switch (i){ case 0: newDNA = evolution.other(newDNA,v); break; case 1: newDNA = evolution.mutation(newDNA,v,0.1); break; case 2: newDNA = evolution.substitution(newDNA,v); break; } return MnistAgent.create(newDNA); } /** * 父本只选择比母本评分高的样本 * @param r * @return */ private MnistAgent getFather(double r) { r += (Math.random() * (1-r)); return getMother(r); } private MnistAgent getMother(double r) { int index = (int) (r * max); return getMnistAgent(index); } private MnistAgent getMnistAgent(int index) { synchronized (lives) { Iterator<Map.Entry<Double, MnistAgent>> it = lives.entrySet().iterator(); for (int i = 0; i < index; i++) { it.next(); } return it.next().getValue(); } } }
主函数
/**
 * Entry point: reports the start time, creates an environment with a capacity of 128
 * agents and starts the evolution.
 */
@Slf4j
public class Program {

    /**
     * @param args command line arguments (unused)
     */
    public static void main(String[] args) {
        // Date.toLocaleString() is deprecated; DateFormat.getDateTimeInstance() is its
        // documented replacement and yields the same text.
        String startTime = java.text.DateFormat.getDateTimeInstance().format(new Date());
        UI.put("开始时间", startTime);

        MnistEnv env = new MnistEnv(128);
        env.evolution(Integer.MAX_VALUE);
    }
}
运行截图
看完上述内容,你们对dl4j如何使用遗传神经网络完成手写数字识别有进一步的了解吗?如果还想了解更多知识或者相关内容,请关注亿速云行业资讯频道,感谢大家的支持。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。