001         package com.croftsoft.apps.backpropxor;
002    
003         import java.awt.Choice;
004         import java.awt.Color;
005         import java.awt.Font;
006         import java.awt.Graphics;
007         import java.awt.Rectangle;
008         import java.awt.event.*;
009         import java.text.DecimalFormat;
010         import javax.swing.*;
011    
012         import com.croftsoft.core.gui.ShutdownWindowListener;
013         import com.croftsoft.core.gui.FrameLib;
014         import com.croftsoft.core.gui.plot.PlotLib;
015         import com.croftsoft.core.lang.lifecycle.Lifecycle;
016         import com.croftsoft.core.math.*;
017    
018         /*********************************************************************
019         * A Backpropagation neural network learning algorithm demonstration.
020         *
021         * @version
022         *   2002-02-28
023         * @since
024         *   1996
025         * @author
026         *   <a href="http://www.alumni.caltech.edu/~croft/">David W. Croft</a>
027         *********************************************************************/
028    
029         public class  BackpropXor
030           extends JPanel
031           implements ActionListener, ItemListener, Lifecycle, Runnable,
032             BackpropXorConstants
033         //////////////////////////////////////////////////////////////////////
034         //////////////////////////////////////////////////////////////////////
035         {
036    
037         private Matrix  l0_activations      = new Matrix ( 3, 1 );
038         private Matrix  l1_activations      = new Matrix ( 2, 1 );
039         private Matrix  l2_inputs           = new Matrix ( 3, 1 );
040         private Matrix  l2_activations      = new Matrix ( 1, 1 );
041         private Matrix  l1_weights          = new Matrix ( 3, 2 );
042         private Matrix  l2_weights          = new Matrix ( 3, 1 );
043         private Matrix  l1_weighted_sums    = new Matrix ( 2, 1 );
044         private Matrix  l2_weighted_sums    = new Matrix ( 1, 1 );
045         private double  output_desired      = 0.0;
046         private double  output_error        = 0.0;
047         private Matrix  l2_local_gradient   = new Matrix ( 1, 1 );
048         private double  learning_rate       = 0.1;
049         private double  momentum_constant   = 0.9;
050         private Matrix  l2_weights_delta    = new Matrix ( 3, 1 );
051         private Matrix  l2_weights_momentum = new Matrix ( 3, 1 );
052         private Matrix  sum_weighted_deltas = new Matrix ( 2, 1 );
053         private Matrix  l1_local_gradients  = new Matrix ( 2, 1 );
054         private Matrix  l1_weights_delta    = new Matrix ( 3, 2 );
055         private Matrix  l1_weights_momentum = new Matrix ( 3, 2 );
056    
057         //
058    
059         private long        total_samples = 0;
060    
061         private int         iteration = 0;
062         private double [ ]  squared_errors = new double [ ITERATIONS_PER_EPOCH ];
063         private int         epoch = 0;
064         private int         epochs_max = 1000;
065         private double [ ]  epoch_rms_error = new double [ epochs_max ];
066    
067         private double [ ] [ ]  samples = new double [ ITERATIONS_PER_EPOCH ] [ 3 ];
068    
069         //
070    
071         private Rectangle  r;
072         private Rectangle  rs;
073    
074         //
075    
076         private Thread   runner;
077    
078         private boolean  pleaseStop = false;
079    
080         private boolean  isPaused;
081    
082         //
083    
084         private JComboBox   functionComboBox;
085    
086         private JTextField  learning_rate_TextField;
087    
088         private JLabel      learning_rate_Label;
089    
090         private JTextField  momentum_TextField;
091    
092         private JLabel      momentum_Label;
093    
094         private JButton     randomize_Button;
095    
096         private JButton     reset_Button;
097    
098         private JButton     pause_Button;
099    
100         //
101    
102         private int     function_selected;
103    
104         private String  function_String;
105    
106         private int  y2;
107    
108         private int  y4;
109    
110         private DecimalFormat  decimalFormat
111           = new DecimalFormat ( "0.000000" );
112    
113         //////////////////////////////////////////////////////////////////////
114         //////////////////////////////////////////////////////////////////////
115    
116         public static void  main ( String [ ]  args )
117         //////////////////////////////////////////////////////////////////////
118         {
119           JFrame  jFrame = new JFrame ( TITLE );
120    
121    /*
122           try
123           {
124             jFrame.setIconImage ( ClassLib.getResourceAsImage (
125               BackpropXor.class, FRAME_ICON_FILENAME ) );
126           }
127           catch ( Exception  ex )
128           {
129             ex.printStackTrace ( );
130           }
131    */
132    
133           BackpropXor  backpropXor = new BackpropXor ( );
134    
135           jFrame.setContentPane ( backpropXor );
136    
137           FrameLib.launchJFrameAsDesktopApp (
138             jFrame,
139             new Lifecycle [ ] { backpropXor },
140             FRAME_SIZE,
141             SHUTDOWN_CONFIRMATION_PROMPT );
142         }
143    
144         //////////////////////////////////////////////////////////////////////
145         // interface Lifecycle methods
146         //////////////////////////////////////////////////////////////////////
147    
148         public void  init ( )
149         //////////////////////////////////////////////////////////////////////
150         {
151           setFont ( new Font ( "Times New Roman", Font.PLAIN, FONT_SIZE ) );
152    
153           function_selected = INITIAL_FUNCTION;
154    
155           function_String = FUNCTION_NAMES [ function_selected ];
156    
157           functionComboBox = new JComboBox ( FUNCTION_NAMES );
158    
159           functionComboBox.setSelectedIndex ( function_selected );
160    
161           learning_rate_TextField = new JTextField ( "" + learning_rate, 4 );
162    
163           learning_rate_Label = new JLabel ( "Learning Rate" );
164    
165           momentum_TextField = new JTextField ( "" + momentum_constant, 4 );
166    
167           momentum_Label = new JLabel ( "Momentum Factor" );
168    
169           add ( functionComboBox );
170           add ( momentum_Label );
171           add ( momentum_TextField );
172           add ( learning_rate_Label );
173           add ( learning_rate_TextField );
174    
175           JPanel  buttonPanel = new JPanel ( );
176    
177           randomize_Button = new JButton ( "Randomize Weights" );
178    
179           reset_Button = new JButton ( "Reset Strip" );
180    
181           pause_Button = new JButton ( "  Pause  " );
182    
183           buttonPanel.add ( randomize_Button );
184           
185           buttonPanel.add ( pause_Button );
186           
187           buttonPanel.add ( reset_Button );
188    
189           add ( buttonPanel );
190    
191           functionComboBox       .addItemListener   ( this );
192    
193           momentum_TextField     .addActionListener ( this );
194    
195           learning_rate_TextField.addActionListener ( this );
196    
197           randomize_Button       .addActionListener ( this );
198    
199           pause_Button           .addActionListener ( this );
200    
201           reset_Button           .addActionListener ( this );
202    
203           r = new Rectangle (
204             10, 10 + Y, SIZE.width - 100, SIZE.height - 350 );
205    
206           y2 = 30 + Y + r.height;
207    
208           y4 = y2 + YTAB * 8;
209    
210           rs = new Rectangle (
211             10 + r.width + 10, 10 + Y,
212             SIZE.height - 350, SIZE.height - 350 );
213    
214           randomize_weights ( );
215         }
216    
217         public void  start ( )
218         //////////////////////////////////////////////////////////////////////
219         {
220           if ( ( runner == null ) && !isPaused )
221           {
222             pleaseStop = false;
223    
224             runner = new Thread ( this );
225    
226             runner.setPriority ( runner.getPriority ( ) - 1 );
227    
228             runner.setDaemon ( true );
229    
230             runner.start ( );
231           }
232         }
233    
234         public void  stop ( )
235         //////////////////////////////////////////////////////////////////////
236         {
237           pleaseStop = true;
238         }
239    
240         public void  destroy ( )
241         //////////////////////////////////////////////////////////////////////
242         {
243           stop ( );
244         }
245    
246         //////////////////////////////////////////////////////////////////////
247         //////////////////////////////////////////////////////////////////////
248    
249         public void  itemStateChanged ( ItemEvent  itemEvent )
250         //////////////////////////////////////////////////////////////////////
251         {
252           Object  source = itemEvent.getSource ( );
253    
254           if ( source == functionComboBox )
255           {
256             function_selected = functionComboBox.getSelectedIndex ( );
257    
258             function_String = ( String ) functionComboBox.getSelectedItem ( );
259           }
260         }
261    
262         public void  actionPerformed ( ActionEvent  actionEvent )
263         //////////////////////////////////////////////////////////////////////
264         {
265           Object  source = actionEvent.getSource ( );
266    
267           if ( source == momentum_TextField )
268           {
269             momentum_constant = Double.valueOf (
270               momentum_TextField.getText ( ) ).doubleValue ( );
271           }
272           else if ( source == learning_rate_TextField )
273           {
274             learning_rate = Double.valueOf (
275               learning_rate_TextField.getText ( ) ).doubleValue ( );
276           }
277           else if ( source == randomize_Button )
278           {
279             randomize_weights ( );
280           }
281           else if ( source == reset_Button )
282           {
283             total_samples = 0;
284    
285             iteration = 0;
286    
287             epoch = 0;
288           }
289           else if ( source == pause_Button )
290           {
291             if ( isPaused )
292             {
293               isPaused = false;
294    
295               pause_Button.setText ( " Pause " );
296    
297               start ( );           
298             }
299             else
300             {
301               isPaused = true;
302    
303               pause_Button.setText ( "Continue" );
304    
305               stop ( );
306             }
307           }
308         }
309    
310         public void  paintComponent ( Graphics g )
311         //////////////////////////////////////////////////////////////////////
312         {
313           super.paintComponent ( g );
314    
315           plot_epochs ( r, g, epoch_rms_error );
316    
317           plotGraph  ( g, samples );
318    
319           g.drawString ( "Inputs", 10 + 0 * XTAB, 10 + 0 * YTAB + y2 );
320           g.drawString ( decimalFormat.format ( l0_activations.data    [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 1 * YTAB + y2 );
321           g.drawString ( decimalFormat.format ( l0_activations.data    [ 1 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 2 * YTAB + y2 );
322           g.drawString ( decimalFormat.format ( l0_activations.data    [ 2 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 3 * YTAB + y2 );
323    
324           g.drawString ( "Weights", 10 + 1 * XTAB, 10 + 0 * YTAB + y2 );
325           g.drawString ( decimalFormat.format ( l1_weights.data        [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 1 * YTAB + y2 );
326           g.drawString ( decimalFormat.format ( l1_weights.data        [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 2 * YTAB + y2 );
327           g.drawString ( decimalFormat.format ( l1_weights.data        [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 3 * YTAB + y2 );
328           g.drawString ( decimalFormat.format ( l1_weights.data        [ 0 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 4 * YTAB + y2 );
329           g.drawString ( decimalFormat.format ( l1_weights.data        [ 1 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 5 * YTAB + y2 );
330           g.drawString ( decimalFormat.format ( l1_weights.data        [ 2 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 6 * YTAB + y2 );
331    
332           g.drawString ( "Weighted Sums", 10 + 2 * XTAB, 10 + 0 * YTAB + y2 );
333           g.drawString ( decimalFormat.format ( l1_weighted_sums.data  [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 1 * YTAB + y2 );
334           g.drawString ( decimalFormat.format ( l1_weighted_sums.data  [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 2 * YTAB + y2 );
335                                                                                         
336           g.drawString ( "Hidden", 10 + 3 * XTAB, 10 + 0 * YTAB + y2 );
337           g.drawString ( decimalFormat.format ( l2_inputs.data         [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 1 * YTAB + y2 );
338           g.drawString ( decimalFormat.format ( l2_inputs.data         [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 2 * YTAB + y2 );
339           g.drawString ( decimalFormat.format ( l2_inputs.data         [ 2 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 3 * YTAB + y2 );
340    
341           g.drawString ( "Weights", 10 + 4 * XTAB, 10 + 0 * YTAB + y2 );
342           g.drawString ( decimalFormat.format ( l2_weights.data        [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 1 * YTAB + y2 );
343           g.drawString ( decimalFormat.format ( l2_weights.data        [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 2 * YTAB + y2 );
344           g.drawString ( decimalFormat.format ( l2_weights.data        [ 2 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 3 * YTAB + y2 );
345    
346           g.drawString ( "Weighted Sum", 10 + 5 * XTAB, 10 + 0 * YTAB + y2 );
347           g.drawString ( decimalFormat.format ( l2_weighted_sums.data  [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 10 + 1 * YTAB + y2 );
348    
349           g.drawString ( "Output", 10 + 6 * XTAB, 10 + 0 * YTAB + y2 );
350           g.drawString ( decimalFormat.format ( l2_activations.data    [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 10 + 1 * YTAB + y2 );
351    
352           g.drawString ( "Output Desired", 10 + 6 * XTAB, 10 + 2 * YTAB + y2 );
353           g.drawString ( decimalFormat.format ( output_desired ), 10 + 6 * XTAB, 10 + 3 * YTAB + y2 );
354    
355           g.drawString ( "Output Error", 10 + 6 * XTAB, 10 + 4 * YTAB + y2 );
356           g.drawString ( decimalFormat.format ( output_error ), 10 + 6 * XTAB, 10 + 5 * YTAB + y2 );
357    
358           g.drawString ( "Iterations", 10 + 7 * XTAB,  10 + 0 * YTAB + y2 );
359           g.drawString ( "" + total_samples, 10 + 7 * XTAB,  10 + 1 * YTAB + y2 );
360    
361           g.drawString ( "Function", 10 + 7 * XTAB,  10 + 2 * YTAB + y2 );
362           g.drawString ( function_String, 10 + 7 * XTAB,  10 + 3 * YTAB + y2 );
363    
364           g.drawString ( "L2 Gradient", 10 + 0 * XTAB, 0 * YTAB + y4 );
365           g.drawString ( decimalFormat.format ( l2_local_gradient.data [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 1 * YTAB + y4 );
366    
367           g.drawString ( "W2 Deltas", 10 + 1 * XTAB, 0 * YTAB + y4 );
368           g.drawString ( decimalFormat.format ( l2_weights_delta.data  [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 1 * YTAB + y4 );
369           g.drawString ( decimalFormat.format ( l2_weights_delta.data  [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 2 * YTAB + y4 );
370           g.drawString ( decimalFormat.format ( l2_weights_delta.data  [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 3 * YTAB + y4 );
371    
372           g.drawString ( "W2 Momentum", 10 + 2 * XTAB, 0 * YTAB + y4 );
373           g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 1 * YTAB + y4 );
374           g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 2 * YTAB + y4 );
375           g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 2 * XTAB, 3 * YTAB + y4 );
376    
377           g.drawString ( "Sum W Deltas", 10 + 3 * XTAB, 0 * YTAB + y4 );
378           g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 1 * YTAB + y4 );
379           g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 2 * YTAB + y4 );
380    
381           g.drawString ( "L1 Gradients", 10 + 4 * XTAB, 0 * YTAB + y4 );
382           g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 1 * YTAB + y4 );
383           g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 2 * YTAB + y4 );
384    
385           g.drawString ( "W1 Deltas", 10 + 5 * XTAB, 0 * YTAB + y4 );
386           g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 1 * YTAB + y4 );
387           g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 1 ] [ 0 ] ), 10 + 5 * XTAB, 2 * YTAB + y4 );
388           g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 2 ] [ 0 ] ), 10 + 5 * XTAB, 3 * YTAB + y4 );
389           g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 0 ] [ 1 ] ), 10 + 5 * XTAB, 4 * YTAB + y4 );
390           g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 1 ] [ 1 ] ), 10 + 5 * XTAB, 5 * YTAB + y4 );
391           g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 2 ] [ 1 ] ), 10 + 5 * XTAB, 6 * YTAB + y4 );
392    
393           g.drawString ( "W1 Momentum", 10 + 6 * XTAB, 0 * YTAB + y4 );
394           g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 1 * YTAB + y4 );
395           g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 6 * XTAB, 2 * YTAB + y4 );
396           g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 6 * XTAB, 3 * YTAB + y4 );
397           g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 1 ] ), 10 + 6 * XTAB, 4 * YTAB + y4 );
398           g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 1 ] ), 10 + 6 * XTAB, 5 * YTAB + y4 );
399           g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 1 ] ), 10 + 6 * XTAB, 6 * YTAB + y4 );
400         }
401    
402         public void  run ( )
403         //////////////////////////////////////////////////////////////////////
404         {
405           try
406           {
407    
408           long  lastRepaintTime = 0;
409    
410           while ( !pleaseStop )
411           {
412             if ( iteration == 0 )
413             {
414               long  currentTime = System.currentTimeMillis ( );
415    
416               if ( currentTime >= lastRepaintTime + REPAINT_PERIOD )
417               {
418                 lastRepaintTime = currentTime;
419    
420                 repaint ( );
421               }
422             }
423    
424             l2_weights = l2_weights.add ( l2_weights_delta );
425    
426    //         l2_weights = l2_weights.clip ( -1.0, 1.0 );
427    
428             l1_weights = l1_weights.add ( l1_weights_delta );
429    
430    //         l1_weights = l1_weights.clip ( -1.0, 1.0 );
431    
432             l0_activations = l0_activations.randomizeUniform ( 0.0, 1.0 );
433    
434             l0_activations.data [ 0 ] [ 0 ] = 1.0;
435    
436             samples [ iteration ] [ 0 ] = l0_activations.data [ 1 ] [ 0 ];
437    
438             samples [ iteration ] [ 1 ] = l0_activations.data [ 2 ] [ 0 ];
439    
440             l1_weighted_sums = Matrix.multiply (
441               l1_weights.transpose ( ), l0_activations );
442    
443             l1_activations = l1_weighted_sums.sigmoid ( );
444    
445             l2_inputs.data [ 0 ] [ 0 ] = 1.0;
446    
447             l2_inputs.data [ 1 ] [ 0 ] = l1_activations.data [ 0 ] [ 0 ];
448    
449             l2_inputs.data [ 2 ] [ 0 ] = l1_activations.data [ 1 ] [ 0 ];
450    
451             l2_weighted_sums = Matrix.multiply (
452               l2_weights.transpose ( ), l2_inputs );
453    
454             l2_activations = l2_weighted_sums.sigmoid ( );
455    
456             samples [ iteration ] [ 2 ] = l2_activations.data [ 0 ] [ 0 ];
457    
458             output_desired = target_function ( l0_activations );
459    
460             output_error = output_desired - l2_activations.data [ 0 ] [ 0 ];
461    
462             l2_local_gradient = l2_weighted_sums.sigmoidDerivative ( );
463    
464             l2_local_gradient = l2_local_gradient.multiply ( output_error );
465    
466    // I'm not sure about the transpose below here.
467    
468             l2_weights_delta
469               = Matrix.multiply ( l2_inputs, l2_local_gradient.transpose ( ) );
470    
471             l2_weights_delta = l2_weights_delta.multiply ( learning_rate );
472    
473             l2_weights_delta = l2_weights_delta.add ( l2_weights_momentum );
474    
475             l2_weights_momentum
476               = l2_weights_delta.multiply ( momentum_constant );
477    
478    // needs transpose or sum below?
479    
480             sum_weighted_deltas = l2_weights.submatrix ( 1, 2, 0, 0 );
481    
482             sum_weighted_deltas = Matrix.multiply (
483               sum_weighted_deltas, l2_local_gradient );
484    
485             l1_local_gradients = new Matrix (
486               l1_activations.rows, l1_activations.cols, 1.0 );
487    
488             l1_local_gradients = l1_local_gradients.subtract ( l1_activations );
489    
490             l1_local_gradients = Matrix.multiplyPairwise (
491               l1_activations, l1_local_gradients );
492    
493             l1_local_gradients = Matrix.multiplyPairwise (
494               l1_local_gradients, sum_weighted_deltas );
495    
496             l1_weights_delta = Matrix.multiply (
497               l0_activations, l1_local_gradients.transpose ( ) );
498    
499             l1_weights_delta = l1_weights_delta.multiply ( learning_rate );
500    
501             l1_weights_delta = l1_weights_delta.add ( l1_weights_momentum );
502    
503             l1_weights_momentum
504               = l1_weights_delta.multiply ( momentum_constant );
505    
506             squared_errors [ iteration ] = output_error * output_error;
507    
508             iteration++;
509    
510             total_samples++;
511    
512             if ( iteration % ITERATIONS_PER_EPOCH == 0 )
513             {
514               epoch++;
515    
516               if ( epoch >= epochs_max ) epoch = 1;
517    
518               iteration = 0;
519    
520               epoch_rms_error [ epoch - 1 ] = 0.0;
521    
522               for ( int index_iteration = 0;
523                         index_iteration < ITERATIONS_PER_EPOCH;
524                         index_iteration++ )
525               {
526                 epoch_rms_error [ epoch - 1 ]
527                   += squared_errors [ index_iteration ];
528               }
529    
530               epoch_rms_error [ epoch - 1 ] /= ( double ) ITERATIONS_PER_EPOCH;
531    
532               epoch_rms_error [ epoch - 1 ]
533                 = Math.sqrt ( epoch_rms_error [ epoch - 1 ] );
534             }
535    
536             Thread.sleep ( 0 );
537           }
538    
539           repaint ( );
540    
541           }
542           catch ( Exception  ex )
543           {
544             ex.printStackTrace ( );
545           }
546           finally
547           {
548             runner = null;
549           }
550         }
551    
552         //////////////////////////////////////////////////////////////////////
553         // private methods
554         //////////////////////////////////////////////////////////////////////
555    
556         private void  plot_epochs (
557           Rectangle  r, Graphics  g, double [ ]  epochs )
558         //////////////////////////////////////////////////////////////////////
559         {
560           g.setColor ( java.awt.Color.black );
561           g.fillRect ( r.x, r.y, r.width, r.height );
562           g.setColor ( java.awt.Color.white );
563           g.drawRect ( r.x, r.y, r.width, r.height / 2 );
564    //       g.clipRect ( r.x, r.y, r.width, r.height );
565           for ( int index_epoch = 1;
566                     index_epoch <= epoch;
567                     index_epoch++ )
568           {
569             PlotLib.xy ( java.awt.Color.red,
570               ( double ) index_epoch, epochs [ index_epoch - 1 ],
571               r, g, 1.0, ( double ) epoch, 0.0, 1.0, OVAL_SIZE, true );
572           }
573    //       g.clipRect ( 0, 0, SIZE.width, SIZE.height );
574           g.setColor ( java.awt.Color.white );
575           g.drawRect ( r.x, r.y, r.width, r.height );
576           g.setColor ( this.getForeground ( ) );
577         }
578    
579         private void  plotGraph ( Graphics  g, double [ ] [ ]  samples )
580         //////////////////////////////////////////////////////////////////////
581         {
582           g.setColor ( Color.black );
583    
584           g.fillRect ( rs.x, rs.y, rs.width, rs.height );
585    
586           for ( int index_iteration = 0;
587             index_iteration < ITERATIONS_PER_EPOCH;
588             index_iteration++ )
589           {
590             PlotLib.xy (
591               samples [ index_iteration ] [ 2 ] >= 0.5
592                 ? Color.green : Color.red,
593               samples [ index_iteration ] [ 0 ],
594               samples [ index_iteration ] [ 1 ],
595               rs, g, 0.0, 1.0, 0.0, 1.0, OVAL_SIZE );
596           }
597    
598           g.setColor ( Color.white );
599    
600           g.drawRect ( rs.x, rs.y, rs.width, rs.height );
601    
602           g.setColor ( this.getForeground ( ) );
603         }
604    
605         private void  randomize_weights ( )
606         //////////////////////////////////////////////////////////////////////
607         {
608           l1_weights = l1_weights.randomizeUniform ( -1.0, 1.0 );
609    
610           l2_weights = l2_weights.randomizeUniform ( -1.0, 1.0 );
611         }
612    
613         private double  target_function ( Matrix  inputs )
614         //////////////////////////////////////////////////////////////////////
615         {
616           long  a = Math.round ( inputs.data [ 1 ] [ 0 ] );
617    
618           long  b = Math.round ( inputs.data [ 2 ] [ 0 ] );
619    
620           long  bit_num = 2 * b + a;
621    
622           long  mask = 1 << bit_num;
623    
624           long  masked = function_selected & mask;
625    
626           return ( masked == mask ) ? 1.0 : 0.0;
627         }
628    
629         //////////////////////////////////////////////////////////////////////
630         //////////////////////////////////////////////////////////////////////
631         }