001 package com.croftsoft.apps.backpropxor;
002
003 import java.awt.Choice;
004 import java.awt.Color;
005 import java.awt.Font;
006 import java.awt.Graphics;
007 import java.awt.Rectangle;
008 import java.awt.event.*;
009 import java.text.DecimalFormat;
010 import javax.swing.*;
011
012 import com.croftsoft.core.gui.ShutdownWindowListener;
013 import com.croftsoft.core.gui.FrameLib;
014 import com.croftsoft.core.gui.plot.PlotLib;
015 import com.croftsoft.core.lang.lifecycle.Lifecycle;
016 import com.croftsoft.core.math.*;
017
018 /*********************************************************************
019 * A Backpropagation neural network learning algorithm demonstration.
020 *
021 * @version
022 * 2002-02-28
023 * @since
024 * 1996
025 * @author
026 * <a href="http://www.alumni.caltech.edu/~croft/">David W. Croft</a>
027 *********************************************************************/
028
029 public class BackpropXor
030 extends JPanel
031 implements ActionListener, ItemListener, Lifecycle, Runnable,
032 BackpropXorConstants
033 //////////////////////////////////////////////////////////////////////
034 //////////////////////////////////////////////////////////////////////
035 {
036
037 private Matrix l0_activations = new Matrix ( 3, 1 );
038 private Matrix l1_activations = new Matrix ( 2, 1 );
039 private Matrix l2_inputs = new Matrix ( 3, 1 );
040 private Matrix l2_activations = new Matrix ( 1, 1 );
041 private Matrix l1_weights = new Matrix ( 3, 2 );
042 private Matrix l2_weights = new Matrix ( 3, 1 );
043 private Matrix l1_weighted_sums = new Matrix ( 2, 1 );
044 private Matrix l2_weighted_sums = new Matrix ( 1, 1 );
045 private double output_desired = 0.0;
046 private double output_error = 0.0;
047 private Matrix l2_local_gradient = new Matrix ( 1, 1 );
048 private double learning_rate = 0.1;
049 private double momentum_constant = 0.9;
050 private Matrix l2_weights_delta = new Matrix ( 3, 1 );
051 private Matrix l2_weights_momentum = new Matrix ( 3, 1 );
052 private Matrix sum_weighted_deltas = new Matrix ( 2, 1 );
053 private Matrix l1_local_gradients = new Matrix ( 2, 1 );
054 private Matrix l1_weights_delta = new Matrix ( 3, 2 );
055 private Matrix l1_weights_momentum = new Matrix ( 3, 2 );
056
057 //
058
059 private long total_samples = 0;
060
061 private int iteration = 0;
062 private double [ ] squared_errors = new double [ ITERATIONS_PER_EPOCH ];
063 private int epoch = 0;
064 private int epochs_max = 1000;
065 private double [ ] epoch_rms_error = new double [ epochs_max ];
066
067 private double [ ] [ ] samples = new double [ ITERATIONS_PER_EPOCH ] [ 3 ];
068
069 //
070
071 private Rectangle r;
072 private Rectangle rs;
073
074 //
075
076 private Thread runner;
077
078 private boolean pleaseStop = false;
079
080 private boolean isPaused;
081
082 //
083
084 private JComboBox functionComboBox;
085
086 private JTextField learning_rate_TextField;
087
088 private JLabel learning_rate_Label;
089
090 private JTextField momentum_TextField;
091
092 private JLabel momentum_Label;
093
094 private JButton randomize_Button;
095
096 private JButton reset_Button;
097
098 private JButton pause_Button;
099
100 //
101
102 private int function_selected;
103
104 private String function_String;
105
106 private int y2;
107
108 private int y4;
109
110 private DecimalFormat decimalFormat
111 = new DecimalFormat ( "0.000000" );
112
113 //////////////////////////////////////////////////////////////////////
114 //////////////////////////////////////////////////////////////////////
115
116 public static void main ( String [ ] args )
117 //////////////////////////////////////////////////////////////////////
118 {
119 JFrame jFrame = new JFrame ( TITLE );
120
121 /*
122 try
123 {
124 jFrame.setIconImage ( ClassLib.getResourceAsImage (
125 BackpropXor.class, FRAME_ICON_FILENAME ) );
126 }
127 catch ( Exception ex )
128 {
129 ex.printStackTrace ( );
130 }
131 */
132
133 BackpropXor backpropXor = new BackpropXor ( );
134
135 jFrame.setContentPane ( backpropXor );
136
137 FrameLib.launchJFrameAsDesktopApp (
138 jFrame,
139 new Lifecycle [ ] { backpropXor },
140 FRAME_SIZE,
141 SHUTDOWN_CONFIRMATION_PROMPT );
142 }
143
144 //////////////////////////////////////////////////////////////////////
145 // interface Lifecycle methods
146 //////////////////////////////////////////////////////////////////////
147
148 public void init ( )
149 //////////////////////////////////////////////////////////////////////
150 {
151 setFont ( new Font ( "Times New Roman", Font.PLAIN, FONT_SIZE ) );
152
153 function_selected = INITIAL_FUNCTION;
154
155 function_String = FUNCTION_NAMES [ function_selected ];
156
157 functionComboBox = new JComboBox ( FUNCTION_NAMES );
158
159 functionComboBox.setSelectedIndex ( function_selected );
160
161 learning_rate_TextField = new JTextField ( "" + learning_rate, 4 );
162
163 learning_rate_Label = new JLabel ( "Learning Rate" );
164
165 momentum_TextField = new JTextField ( "" + momentum_constant, 4 );
166
167 momentum_Label = new JLabel ( "Momentum Factor" );
168
169 add ( functionComboBox );
170 add ( momentum_Label );
171 add ( momentum_TextField );
172 add ( learning_rate_Label );
173 add ( learning_rate_TextField );
174
175 JPanel buttonPanel = new JPanel ( );
176
177 randomize_Button = new JButton ( "Randomize Weights" );
178
179 reset_Button = new JButton ( "Reset Strip" );
180
181 pause_Button = new JButton ( " Pause " );
182
183 buttonPanel.add ( randomize_Button );
184
185 buttonPanel.add ( pause_Button );
186
187 buttonPanel.add ( reset_Button );
188
189 add ( buttonPanel );
190
191 functionComboBox .addItemListener ( this );
192
193 momentum_TextField .addActionListener ( this );
194
195 learning_rate_TextField.addActionListener ( this );
196
197 randomize_Button .addActionListener ( this );
198
199 pause_Button .addActionListener ( this );
200
201 reset_Button .addActionListener ( this );
202
203 r = new Rectangle (
204 10, 10 + Y, SIZE.width - 100, SIZE.height - 350 );
205
206 y2 = 30 + Y + r.height;
207
208 y4 = y2 + YTAB * 8;
209
210 rs = new Rectangle (
211 10 + r.width + 10, 10 + Y,
212 SIZE.height - 350, SIZE.height - 350 );
213
214 randomize_weights ( );
215 }
216
217 public void start ( )
218 //////////////////////////////////////////////////////////////////////
219 {
220 if ( ( runner == null ) && !isPaused )
221 {
222 pleaseStop = false;
223
224 runner = new Thread ( this );
225
226 runner.setPriority ( runner.getPriority ( ) - 1 );
227
228 runner.setDaemon ( true );
229
230 runner.start ( );
231 }
232 }
233
234 public void stop ( )
235 //////////////////////////////////////////////////////////////////////
236 {
237 pleaseStop = true;
238 }
239
240 public void destroy ( )
241 //////////////////////////////////////////////////////////////////////
242 {
243 stop ( );
244 }
245
246 //////////////////////////////////////////////////////////////////////
247 //////////////////////////////////////////////////////////////////////
248
249 public void itemStateChanged ( ItemEvent itemEvent )
250 //////////////////////////////////////////////////////////////////////
251 {
252 Object source = itemEvent.getSource ( );
253
254 if ( source == functionComboBox )
255 {
256 function_selected = functionComboBox.getSelectedIndex ( );
257
258 function_String = ( String ) functionComboBox.getSelectedItem ( );
259 }
260 }
261
262 public void actionPerformed ( ActionEvent actionEvent )
263 //////////////////////////////////////////////////////////////////////
264 {
265 Object source = actionEvent.getSource ( );
266
267 if ( source == momentum_TextField )
268 {
269 momentum_constant = Double.valueOf (
270 momentum_TextField.getText ( ) ).doubleValue ( );
271 }
272 else if ( source == learning_rate_TextField )
273 {
274 learning_rate = Double.valueOf (
275 learning_rate_TextField.getText ( ) ).doubleValue ( );
276 }
277 else if ( source == randomize_Button )
278 {
279 randomize_weights ( );
280 }
281 else if ( source == reset_Button )
282 {
283 total_samples = 0;
284
285 iteration = 0;
286
287 epoch = 0;
288 }
289 else if ( source == pause_Button )
290 {
291 if ( isPaused )
292 {
293 isPaused = false;
294
295 pause_Button.setText ( " Pause " );
296
297 start ( );
298 }
299 else
300 {
301 isPaused = true;
302
303 pause_Button.setText ( "Continue" );
304
305 stop ( );
306 }
307 }
308 }
309
310 public void paintComponent ( Graphics g )
311 //////////////////////////////////////////////////////////////////////
312 {
313 super.paintComponent ( g );
314
315 plot_epochs ( r, g, epoch_rms_error );
316
317 plotGraph ( g, samples );
318
319 g.drawString ( "Inputs", 10 + 0 * XTAB, 10 + 0 * YTAB + y2 );
320 g.drawString ( decimalFormat.format ( l0_activations.data [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 1 * YTAB + y2 );
321 g.drawString ( decimalFormat.format ( l0_activations.data [ 1 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 2 * YTAB + y2 );
322 g.drawString ( decimalFormat.format ( l0_activations.data [ 2 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 3 * YTAB + y2 );
323
324 g.drawString ( "Weights", 10 + 1 * XTAB, 10 + 0 * YTAB + y2 );
325 g.drawString ( decimalFormat.format ( l1_weights.data [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 1 * YTAB + y2 );
326 g.drawString ( decimalFormat.format ( l1_weights.data [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 2 * YTAB + y2 );
327 g.drawString ( decimalFormat.format ( l1_weights.data [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 3 * YTAB + y2 );
328 g.drawString ( decimalFormat.format ( l1_weights.data [ 0 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 4 * YTAB + y2 );
329 g.drawString ( decimalFormat.format ( l1_weights.data [ 1 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 5 * YTAB + y2 );
330 g.drawString ( decimalFormat.format ( l1_weights.data [ 2 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 6 * YTAB + y2 );
331
332 g.drawString ( "Weighted Sums", 10 + 2 * XTAB, 10 + 0 * YTAB + y2 );
333 g.drawString ( decimalFormat.format ( l1_weighted_sums.data [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 1 * YTAB + y2 );
334 g.drawString ( decimalFormat.format ( l1_weighted_sums.data [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 2 * YTAB + y2 );
335
336 g.drawString ( "Hidden", 10 + 3 * XTAB, 10 + 0 * YTAB + y2 );
337 g.drawString ( decimalFormat.format ( l2_inputs.data [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 1 * YTAB + y2 );
338 g.drawString ( decimalFormat.format ( l2_inputs.data [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 2 * YTAB + y2 );
339 g.drawString ( decimalFormat.format ( l2_inputs.data [ 2 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 3 * YTAB + y2 );
340
341 g.drawString ( "Weights", 10 + 4 * XTAB, 10 + 0 * YTAB + y2 );
342 g.drawString ( decimalFormat.format ( l2_weights.data [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 1 * YTAB + y2 );
343 g.drawString ( decimalFormat.format ( l2_weights.data [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 2 * YTAB + y2 );
344 g.drawString ( decimalFormat.format ( l2_weights.data [ 2 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 3 * YTAB + y2 );
345
346 g.drawString ( "Weighted Sum", 10 + 5 * XTAB, 10 + 0 * YTAB + y2 );
347 g.drawString ( decimalFormat.format ( l2_weighted_sums.data [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 10 + 1 * YTAB + y2 );
348
349 g.drawString ( "Output", 10 + 6 * XTAB, 10 + 0 * YTAB + y2 );
350 g.drawString ( decimalFormat.format ( l2_activations.data [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 10 + 1 * YTAB + y2 );
351
352 g.drawString ( "Output Desired", 10 + 6 * XTAB, 10 + 2 * YTAB + y2 );
353 g.drawString ( decimalFormat.format ( output_desired ), 10 + 6 * XTAB, 10 + 3 * YTAB + y2 );
354
355 g.drawString ( "Output Error", 10 + 6 * XTAB, 10 + 4 * YTAB + y2 );
356 g.drawString ( decimalFormat.format ( output_error ), 10 + 6 * XTAB, 10 + 5 * YTAB + y2 );
357
358 g.drawString ( "Iterations", 10 + 7 * XTAB, 10 + 0 * YTAB + y2 );
359 g.drawString ( "" + total_samples, 10 + 7 * XTAB, 10 + 1 * YTAB + y2 );
360
361 g.drawString ( "Function", 10 + 7 * XTAB, 10 + 2 * YTAB + y2 );
362 g.drawString ( function_String, 10 + 7 * XTAB, 10 + 3 * YTAB + y2 );
363
364 g.drawString ( "L2 Gradient", 10 + 0 * XTAB, 0 * YTAB + y4 );
365 g.drawString ( decimalFormat.format ( l2_local_gradient.data [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 1 * YTAB + y4 );
366
367 g.drawString ( "W2 Deltas", 10 + 1 * XTAB, 0 * YTAB + y4 );
368 g.drawString ( decimalFormat.format ( l2_weights_delta.data [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 1 * YTAB + y4 );
369 g.drawString ( decimalFormat.format ( l2_weights_delta.data [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 2 * YTAB + y4 );
370 g.drawString ( decimalFormat.format ( l2_weights_delta.data [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 3 * YTAB + y4 );
371
372 g.drawString ( "W2 Momentum", 10 + 2 * XTAB, 0 * YTAB + y4 );
373 g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 1 * YTAB + y4 );
374 g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 2 * YTAB + y4 );
375 g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 2 * XTAB, 3 * YTAB + y4 );
376
377 g.drawString ( "Sum W Deltas", 10 + 3 * XTAB, 0 * YTAB + y4 );
378 g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 1 * YTAB + y4 );
379 g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 2 * YTAB + y4 );
380
381 g.drawString ( "L1 Gradients", 10 + 4 * XTAB, 0 * YTAB + y4 );
382 g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 1 * YTAB + y4 );
383 g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 2 * YTAB + y4 );
384
385 g.drawString ( "W1 Deltas", 10 + 5 * XTAB, 0 * YTAB + y4 );
386 g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 1 * YTAB + y4 );
387 g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 1 ] [ 0 ] ), 10 + 5 * XTAB, 2 * YTAB + y4 );
388 g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 2 ] [ 0 ] ), 10 + 5 * XTAB, 3 * YTAB + y4 );
389 g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 0 ] [ 1 ] ), 10 + 5 * XTAB, 4 * YTAB + y4 );
390 g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 1 ] [ 1 ] ), 10 + 5 * XTAB, 5 * YTAB + y4 );
391 g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 2 ] [ 1 ] ), 10 + 5 * XTAB, 6 * YTAB + y4 );
392
393 g.drawString ( "W1 Momentum", 10 + 6 * XTAB, 0 * YTAB + y4 );
394 g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 1 * YTAB + y4 );
395 g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 6 * XTAB, 2 * YTAB + y4 );
396 g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 6 * XTAB, 3 * YTAB + y4 );
397 g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 1 ] ), 10 + 6 * XTAB, 4 * YTAB + y4 );
398 g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 1 ] ), 10 + 6 * XTAB, 5 * YTAB + y4 );
399 g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 1 ] ), 10 + 6 * XTAB, 6 * YTAB + y4 );
400 }
401
402 public void run ( )
403 //////////////////////////////////////////////////////////////////////
404 {
405 try
406 {
407
408 long lastRepaintTime = 0;
409
410 while ( !pleaseStop )
411 {
412 if ( iteration == 0 )
413 {
414 long currentTime = System.currentTimeMillis ( );
415
416 if ( currentTime >= lastRepaintTime + REPAINT_PERIOD )
417 {
418 lastRepaintTime = currentTime;
419
420 repaint ( );
421 }
422 }
423
424 l2_weights = l2_weights.add ( l2_weights_delta );
425
426 // l2_weights = l2_weights.clip ( -1.0, 1.0 );
427
428 l1_weights = l1_weights.add ( l1_weights_delta );
429
430 // l1_weights = l1_weights.clip ( -1.0, 1.0 );
431
432 l0_activations = l0_activations.randomizeUniform ( 0.0, 1.0 );
433
434 l0_activations.data [ 0 ] [ 0 ] = 1.0;
435
436 samples [ iteration ] [ 0 ] = l0_activations.data [ 1 ] [ 0 ];
437
438 samples [ iteration ] [ 1 ] = l0_activations.data [ 2 ] [ 0 ];
439
440 l1_weighted_sums = Matrix.multiply (
441 l1_weights.transpose ( ), l0_activations );
442
443 l1_activations = l1_weighted_sums.sigmoid ( );
444
445 l2_inputs.data [ 0 ] [ 0 ] = 1.0;
446
447 l2_inputs.data [ 1 ] [ 0 ] = l1_activations.data [ 0 ] [ 0 ];
448
449 l2_inputs.data [ 2 ] [ 0 ] = l1_activations.data [ 1 ] [ 0 ];
450
451 l2_weighted_sums = Matrix.multiply (
452 l2_weights.transpose ( ), l2_inputs );
453
454 l2_activations = l2_weighted_sums.sigmoid ( );
455
456 samples [ iteration ] [ 2 ] = l2_activations.data [ 0 ] [ 0 ];
457
458 output_desired = target_function ( l0_activations );
459
460 output_error = output_desired - l2_activations.data [ 0 ] [ 0 ];
461
462 l2_local_gradient = l2_weighted_sums.sigmoidDerivative ( );
463
464 l2_local_gradient = l2_local_gradient.multiply ( output_error );
465
466 // I'm not sure about the transpose below here.
467
468 l2_weights_delta
469 = Matrix.multiply ( l2_inputs, l2_local_gradient.transpose ( ) );
470
471 l2_weights_delta = l2_weights_delta.multiply ( learning_rate );
472
473 l2_weights_delta = l2_weights_delta.add ( l2_weights_momentum );
474
475 l2_weights_momentum
476 = l2_weights_delta.multiply ( momentum_constant );
477
478 // needs transpose or sum below?
479
480 sum_weighted_deltas = l2_weights.submatrix ( 1, 2, 0, 0 );
481
482 sum_weighted_deltas = Matrix.multiply (
483 sum_weighted_deltas, l2_local_gradient );
484
485 l1_local_gradients = new Matrix (
486 l1_activations.rows, l1_activations.cols, 1.0 );
487
488 l1_local_gradients = l1_local_gradients.subtract ( l1_activations );
489
490 l1_local_gradients = Matrix.multiplyPairwise (
491 l1_activations, l1_local_gradients );
492
493 l1_local_gradients = Matrix.multiplyPairwise (
494 l1_local_gradients, sum_weighted_deltas );
495
496 l1_weights_delta = Matrix.multiply (
497 l0_activations, l1_local_gradients.transpose ( ) );
498
499 l1_weights_delta = l1_weights_delta.multiply ( learning_rate );
500
501 l1_weights_delta = l1_weights_delta.add ( l1_weights_momentum );
502
503 l1_weights_momentum
504 = l1_weights_delta.multiply ( momentum_constant );
505
506 squared_errors [ iteration ] = output_error * output_error;
507
508 iteration++;
509
510 total_samples++;
511
512 if ( iteration % ITERATIONS_PER_EPOCH == 0 )
513 {
514 epoch++;
515
516 if ( epoch >= epochs_max ) epoch = 1;
517
518 iteration = 0;
519
520 epoch_rms_error [ epoch - 1 ] = 0.0;
521
522 for ( int index_iteration = 0;
523 index_iteration < ITERATIONS_PER_EPOCH;
524 index_iteration++ )
525 {
526 epoch_rms_error [ epoch - 1 ]
527 += squared_errors [ index_iteration ];
528 }
529
530 epoch_rms_error [ epoch - 1 ] /= ( double ) ITERATIONS_PER_EPOCH;
531
532 epoch_rms_error [ epoch - 1 ]
533 = Math.sqrt ( epoch_rms_error [ epoch - 1 ] );
534 }
535
536 Thread.sleep ( 0 );
537 }
538
539 repaint ( );
540
541 }
542 catch ( Exception ex )
543 {
544 ex.printStackTrace ( );
545 }
546 finally
547 {
548 runner = null;
549 }
550 }
551
552 //////////////////////////////////////////////////////////////////////
553 // private methods
554 //////////////////////////////////////////////////////////////////////
555
556 private void plot_epochs (
557 Rectangle r, Graphics g, double [ ] epochs )
558 //////////////////////////////////////////////////////////////////////
559 {
560 g.setColor ( java.awt.Color.black );
561 g.fillRect ( r.x, r.y, r.width, r.height );
562 g.setColor ( java.awt.Color.white );
563 g.drawRect ( r.x, r.y, r.width, r.height / 2 );
564 // g.clipRect ( r.x, r.y, r.width, r.height );
565 for ( int index_epoch = 1;
566 index_epoch <= epoch;
567 index_epoch++ )
568 {
569 PlotLib.xy ( java.awt.Color.red,
570 ( double ) index_epoch, epochs [ index_epoch - 1 ],
571 r, g, 1.0, ( double ) epoch, 0.0, 1.0, OVAL_SIZE, true );
572 }
573 // g.clipRect ( 0, 0, SIZE.width, SIZE.height );
574 g.setColor ( java.awt.Color.white );
575 g.drawRect ( r.x, r.y, r.width, r.height );
576 g.setColor ( this.getForeground ( ) );
577 }
578
579 private void plotGraph ( Graphics g, double [ ] [ ] samples )
580 //////////////////////////////////////////////////////////////////////
581 {
582 g.setColor ( Color.black );
583
584 g.fillRect ( rs.x, rs.y, rs.width, rs.height );
585
586 for ( int index_iteration = 0;
587 index_iteration < ITERATIONS_PER_EPOCH;
588 index_iteration++ )
589 {
590 PlotLib.xy (
591 samples [ index_iteration ] [ 2 ] >= 0.5
592 ? Color.green : Color.red,
593 samples [ index_iteration ] [ 0 ],
594 samples [ index_iteration ] [ 1 ],
595 rs, g, 0.0, 1.0, 0.0, 1.0, OVAL_SIZE );
596 }
597
598 g.setColor ( Color.white );
599
600 g.drawRect ( rs.x, rs.y, rs.width, rs.height );
601
602 g.setColor ( this.getForeground ( ) );
603 }
604
605 private void randomize_weights ( )
606 //////////////////////////////////////////////////////////////////////
607 {
608 l1_weights = l1_weights.randomizeUniform ( -1.0, 1.0 );
609
610 l2_weights = l2_weights.randomizeUniform ( -1.0, 1.0 );
611 }
612
613 private double target_function ( Matrix inputs )
614 //////////////////////////////////////////////////////////////////////
615 {
616 long a = Math.round ( inputs.data [ 1 ] [ 0 ] );
617
618 long b = Math.round ( inputs.data [ 2 ] [ 0 ] );
619
620 long bit_num = 2 * b + a;
621
622 long mask = 1 << bit_num;
623
624 long masked = function_selected & mask;
625
626 return ( masked == mask ) ? 1.0 : 0.0;
627 }
628
629 //////////////////////////////////////////////////////////////////////
630 //////////////////////////////////////////////////////////////////////
631 }