package com.croftsoft.apps.backpropxor;

import java.awt.Choice;
import java.awt.Color;
import java.awt.Font;
import java.awt.Graphics;
import java.awt.Rectangle;
import java.awt.event.*;
import java.text.DecimalFormat;
import javax.swing.*;

import com.croftsoft.core.gui.ShutdownWindowListener;
import com.croftsoft.core.gui.FrameLib;
import com.croftsoft.core.gui.plot.PlotLib;
import com.croftsoft.core.lang.lifecycle.Lifecycle;
import com.croftsoft.core.math.*;

/*********************************************************************
* A Backpropagation neural network learning algorithm demonstration.
*
* <p>
* The network has two inputs plus a bias term, two sigmoid hidden
* units plus a bias term, and one sigmoid output.  A background thread
* repeatedly draws a random input pair, runs a forward pass, computes
* the error against the currently selected two-input boolean function,
* and derives weight deltas (with momentum) which are applied at the
* start of the next iteration.  The panel plots the RMS error per epoch
* and the classification of the samples of the current epoch.
* </p>
*
* <p>
* Threading:  the training loop runs on its own thread while Swing
* callbacks run on the event dispatch thread.  Starting and stopping
* of the thread and the iteration/epoch counters are guarded by an
* internal lock; the network matrices themselves are read by the
* painting code without synchronization, so a painted frame may mix
* values from adjacent iterations.
* </p>
*
* @version
*   2002-02-28
* @since
*   1996
* @author
*   <a href="http://www.alumni.caltech.edu/~croft/">David W. Croft</a>
*********************************************************************/

public class BackpropXor
  extends JPanel
  implements ActionListener, ItemListener, Lifecycle, Runnable,
    BackpropXorConstants
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
{

  // network state; element [ 0 ] [ 0 ] of each input vector is the bias

  private Matrix  l0_activations      = new Matrix ( 3, 1 );
  private Matrix  l1_activations      = new Matrix ( 2, 1 );
  private Matrix  l2_inputs           = new Matrix ( 3, 1 );
  private Matrix  l2_activations      = new Matrix ( 1, 1 );
  private Matrix  l1_weights          = new Matrix ( 3, 2 );
  private Matrix  l2_weights          = new Matrix ( 3, 1 );
  private Matrix  l1_weighted_sums    = new Matrix ( 2, 1 );
  private Matrix  l2_weighted_sums    = new Matrix ( 1, 1 );
  private double  output_desired      = 0.0;
  private double  output_error        = 0.0;
  private Matrix  l2_local_gradient   = new Matrix ( 1, 1 );
  private double  learning_rate       = 0.1;
  private double  momentum_constant   = 0.9;
  private Matrix  l2_weights_delta    = new Matrix ( 3, 1 );
  private Matrix  l2_weights_momentum = new Matrix ( 3, 1 );
  private Matrix  sum_weighted_deltas = new Matrix ( 2, 1 );
  private Matrix  l1_local_gradients  = new Matrix ( 2, 1 );
  private Matrix  l1_weights_delta    = new Matrix ( 3, 2 );
  private Matrix  l1_weights_momentum = new Matrix ( 3, 2 );

  // training bookkeeping

  private long  total_samples = 0;

  private int  iteration = 0;

  private double [ ]  squared_errors
    = new double [ ITERATIONS_PER_EPOCH ];

  private int  epoch = 0;

  private int  epochs_max = 1000;

  private double [ ]  epoch_rms_error = new double [ epochs_max ];

  // per-sample { input 1, input 2, network output } for the current epoch

  private double [ ] [ ]  samples
    = new double [ ITERATIONS_PER_EPOCH ] [ 3 ];

  // plot areas:  r is the epoch error strip, rs is the sample scatter plot

  private Rectangle  r;

  private Rectangle  rs;

  // training thread control

  /** Guards runner hand-over and the iteration/epoch counters. */
  private final Object  stateLock = new Object ( );

  private Thread  runner;

  /** Written by stop ( ) on any thread and polled by the training loop. */
  private volatile boolean  pleaseStop = false;

  private boolean  isPaused;

  // widgets

  private JComboBox   functionComboBox;

  private JTextField  learning_rate_TextField;

  private JLabel      learning_rate_Label;

  private JTextField  momentum_TextField;

  private JLabel      momentum_Label;

  private JButton     randomize_Button;

  private JButton     reset_Button;

  private JButton     pause_Button;

  // display state

  /** Truth table of the target function; see target_function ( ). */
  private int  function_selected;

  private String  function_String;

  /** Baseline y of the upper table of values. */
  private int  y2;

  /** Baseline y of the lower table of values. */
  private int  y4;

  private DecimalFormat  decimalFormat
    = new DecimalFormat ( "0.000000" );

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////

  /*********************************************************************
  * Launches the demonstration as a desktop application.
  *
  * @param  args
  *   Ignored.
  *********************************************************************/
  public static void  main ( String [ ]  args )
  //////////////////////////////////////////////////////////////////////
  {
    JFrame  jFrame = new JFrame ( TITLE );

/*
    try
    {
      jFrame.setIconImage ( ClassLib.getResourceAsImage (
        BackpropXor.class, FRAME_ICON_FILENAME ) );
    }
    catch ( Exception  ex )
    {
      ex.printStackTrace ( );
    }
*/

    BackpropXor  backpropXor = new BackpropXor ( );

    jFrame.setContentPane ( backpropXor );

    FrameLib.launchJFrameAsDesktopApp (
      jFrame,
      new Lifecycle [ ] { backpropXor },
      FRAME_SIZE,
      SHUTDOWN_CONFIRMATION_PROMPT );
  }

  //////////////////////////////////////////////////////////////////////
  // interface Lifecycle methods
  //////////////////////////////////////////////////////////////////////

  /*********************************************************************
  * Builds the controls, registers this panel as their listener, lays
  * out the two plot rectangles and randomizes the initial weights.
  *********************************************************************/
  public void  init ( )
  //////////////////////////////////////////////////////////////////////
  {
    setFont ( new Font ( "Times New Roman", Font.PLAIN, FONT_SIZE ) );

    function_selected = INITIAL_FUNCTION;

    function_String = FUNCTION_NAMES [ function_selected ];

    functionComboBox = new JComboBox ( FUNCTION_NAMES );

    functionComboBox.setSelectedIndex ( function_selected );

    learning_rate_TextField = new JTextField ( "" + learning_rate, 4 );

    learning_rate_Label = new JLabel ( "Learning Rate" );

    momentum_TextField = new JTextField ( "" + momentum_constant, 4 );

    momentum_Label = new JLabel ( "Momentum Factor" );

    add ( functionComboBox );

    add ( momentum_Label );

    add ( momentum_TextField );

    add ( learning_rate_Label );

    add ( learning_rate_TextField );

    JPanel  buttonPanel = new JPanel ( );

    randomize_Button = new JButton ( "Randomize Weights" );

    reset_Button = new JButton ( "Reset Strip" );

    pause_Button = new JButton ( " Pause " );

    buttonPanel.add ( randomize_Button );

    buttonPanel.add ( pause_Button );

    buttonPanel.add ( reset_Button );

    add ( buttonPanel );

    functionComboBox.addItemListener ( this );

    momentum_TextField.addActionListener ( this );

    learning_rate_TextField.addActionListener ( this );

    randomize_Button.addActionListener ( this );

    pause_Button.addActionListener ( this );

    reset_Button.addActionListener ( this );

    r = new Rectangle (
      10, 10 + Y, SIZE.width - 100, SIZE.height - 350 );

    y2 = 30 + Y + r.height;

    y4 = y2 + YTAB * 8;

    rs = new Rectangle (
      10 + r.width + 10, 10 + Y,
      SIZE.height - 350, SIZE.height - 350 );

    randomize_weights ( );
  }

  /*********************************************************************
  * Starts (or resumes) training unless the user has paused it.
  *
  * <p>
  * If a previous training thread was asked to stop but has not yet
  * noticed, the stop request is withdrawn and that thread simply keeps
  * running; otherwise a new daemon thread is created.  Either way the
  * panel is training when this method returns, which the original
  * "runner == null" test alone did not guarantee.
  * </p>
  *********************************************************************/
  public void  start ( )
  //////////////////////////////////////////////////////////////////////
  {
    synchronized ( stateLock )
    {
      if ( isPaused )
      {
        return;
      }

      pleaseStop = false;

      if ( runner == null )
      {
        runner = new Thread ( this );

        // one notch below the creating thread, but never below the
        // minimum, which setPriority ( ) would reject

        runner.setPriority ( Math.max (
          Thread.MIN_PRIORITY, runner.getPriority ( ) - 1 ) );

        runner.setDaemon ( true );

        runner.start ( );
      }
    }
  }

  /*********************************************************************
  * Asks the training thread to finish after its current iteration.
  * Does not wait for the thread to exit.
  *********************************************************************/
  public void  stop ( )
  //////////////////////////////////////////////////////////////////////
  {
    pleaseStop = true;
  }

  /*********************************************************************
  * Stops training; there are no other resources to release.
  *********************************************************************/
  public void  destroy ( )
  //////////////////////////////////////////////////////////////////////
  {
    stop ( );
  }

  //////////////////////////////////////////////////////////////////////
  // listener methods
  //////////////////////////////////////////////////////////////////////

  /*********************************************************************
  * Tracks the target function chosen in the combo box.
  *********************************************************************/
  public void  itemStateChanged ( ItemEvent  itemEvent )
  //////////////////////////////////////////////////////////////////////
  {
    if ( itemEvent.getSource ( ) == functionComboBox )
    {
      function_selected = functionComboBox.getSelectedIndex ( );

      function_String = ( String ) functionComboBox.getSelectedItem ( );
    }
  }

  /*********************************************************************
  * Handles the text fields (Enter key) and the three buttons.
  *
  * <p>
  * Text that does not parse as a number leaves the current parameter
  * unchanged and the field is reset to show that value.
  * </p>
  *********************************************************************/
  public void  actionPerformed ( ActionEvent  actionEvent )
  //////////////////////////////////////////////////////////////////////
  {
    Object  source = actionEvent.getSource ( );

    if ( source == momentum_TextField )
    {
      momentum_constant
        = parseField ( momentum_TextField, momentum_constant );
    }
    else if ( source == learning_rate_TextField )
    {
      learning_rate
        = parseField ( learning_rate_TextField, learning_rate );
    }
    else if ( source == randomize_Button )
    {
      randomize_weights ( );

      // no training thread repaints while paused

      repaint ( );
    }
    else if ( source == reset_Button )
    {
      // same lock as the training loop's bookkeeping so the counters
      // cannot be zeroed half way through an epoch roll-over

      synchronized ( stateLock )
      {
        total_samples = 0;

        iteration = 0;

        epoch = 0;
      }

      repaint ( );
    }
    else if ( source == pause_Button )
    {
      if ( isPaused )
      {
        isPaused = false;

        pause_Button.setText ( " Pause " );

        start ( );
      }
      else
      {
        isPaused = true;

        pause_Button.setText ( "Continue" );

        stop ( );
      }
    }
  }

  /*********************************************************************
  * Paints the epoch error strip, the sample scatter plot and two
  * tables showing every value of the forward and backward passes.
  *********************************************************************/
  public void  paintComponent ( Graphics  g )
  //////////////////////////////////////////////////////////////////////
  {
    super.paintComponent ( g );

    if ( ( r == null ) || ( rs == null ) )
    {
      // init ( ) has not laid out the plot areas yet

      return;
    }

    plot_epochs ( r, g, epoch_rms_error );

    plotGraph ( g, samples );

    // upper table:  forward pass

    int  top = 10 + y2;

    drawMatrixColumn ( g, "Inputs"       , l0_activations  , 0, top );

    drawMatrixColumn ( g, "Weights"      , l1_weights      , 1, top );

    drawMatrixColumn ( g, "Weighted Sums", l1_weighted_sums, 2, top );

    drawMatrixColumn ( g, "Hidden"       , l2_inputs       , 3, top );

    drawMatrixColumn ( g, "Weights"      , l2_weights      , 4, top );

    drawMatrixColumn ( g, "Weighted Sum" , l2_weighted_sums, 5, top );

    drawMatrixColumn ( g, "Output"       , l2_activations  , 6, top );

    drawCell ( g, "Output Desired", 6, 2, top );

    drawCell ( g, decimalFormat.format ( output_desired ), 6, 3, top );

    drawCell ( g, "Output Error", 6, 4, top );

    drawCell ( g, decimalFormat.format ( output_error ), 6, 5, top );

    drawCell ( g, "Iterations", 7, 0, top );

    drawCell ( g, "" + total_samples, 7, 1, top );

    drawCell ( g, "Function", 7, 2, top );

    drawCell ( g, function_String, 7, 3, top );

    // lower table:  backward pass

    drawMatrixColumn ( g, "L2 Gradient" , l2_local_gradient  , 0, y4 );

    drawMatrixColumn ( g, "W2 Deltas"   , l2_weights_delta   , 1, y4 );

    drawMatrixColumn ( g, "W2 Momentum" , l2_weights_momentum, 2, y4 );

    drawMatrixColumn ( g, "Sum W Deltas", sum_weighted_deltas, 3, y4 );

    drawMatrixColumn ( g, "L1 Gradients", l1_local_gradients , 4, y4 );

    drawMatrixColumn ( g, "W1 Deltas"   , l1_weights_delta   , 5, y4 );

    drawMatrixColumn ( g, "W1 Momentum" , l1_weights_momentum, 6, y4 );
  }

  /*********************************************************************
  * The training loop.  Runs until stop ( ) is requested, then issues a
  * final repaint.
  *
  * <p>
  * Each pass applies the deltas computed by the previous pass, trains
  * on one random sample and updates the counters.  A repaint is
  * requested at most once per REPAINT_PERIOD milliseconds and only at
  * the start of an epoch.
  * </p>
  *********************************************************************/
  public void  run ( )
  //////////////////////////////////////////////////////////////////////
  {
    try
    {
      long  lastRepaintTime = 0;

      while ( keepRunning ( ) )
      {
        if ( iteration == 0 )
        {
          long  currentTime = System.currentTimeMillis ( );

          if ( currentTime >= lastRepaintTime + REPAINT_PERIOD )
          {
            lastRepaintTime = currentTime;

            repaint ( );
          }
        }

        applyWeightDeltas ( );

        feedForward ( );

        backpropagate ( );

        recordIteration ( );

        Thread.sleep ( 0 );
      }

      repaint ( );
    }
    catch ( InterruptedException  ex )
    {
      // keep the interrupt visible to whoever owns this thread

      Thread.currentThread ( ).interrupt ( );
    }
    catch ( Exception  ex )
    {
      ex.printStackTrace ( );
    }
    finally
    {
      synchronized ( stateLock )
      {
        // only clear our own registration; start ( ) may already have
        // installed a successor thread

        if ( runner == Thread.currentThread ( ) )
        {
          runner = null;
        }
      }
    }
  }

  //////////////////////////////////////////////////////////////////////
  // private methods
  //////////////////////////////////////////////////////////////////////

  /*********************************************************************
  * Decides whether the training loop performs another iteration.
  *
  * <p>
  * When a stop has been requested the calling thread deregisters
  * itself under the same lock that start ( ) holds, so a concurrent
  * start ( ) either withdraws the request before it is seen here or
  * finds no runner and creates a new one.
  * </p>
  *
  * @return
  *   true to continue training, false to leave the loop.
  *********************************************************************/
  private boolean  keepRunning ( )
  //////////////////////////////////////////////////////////////////////
  {
    synchronized ( stateLock )
    {
      if ( !pleaseStop )
      {
        return true;
      }

      if ( runner == Thread.currentThread ( ) )
      {
        runner = null;
      }

      return false;
    }
  }

  /*********************************************************************
  * Adds the deltas computed during the previous iteration to the
  * weights of both layers.
  *********************************************************************/
  private void  applyWeightDeltas ( )
  //////////////////////////////////////////////////////////////////////
  {
    l2_weights = l2_weights.add ( l2_weights_delta );

    // l2_weights = l2_weights.clip ( -1.0, 1.0 );

    l1_weights = l1_weights.add ( l1_weights_delta );

    // l1_weights = l1_weights.clip ( -1.0, 1.0 );
  }

  /*********************************************************************
  * Draws a random input pair, runs it through both layers and stores
  * the inputs and the resulting output in the sample table.
  *********************************************************************/
  private void  feedForward ( )
  //////////////////////////////////////////////////////////////////////
  {
    // one snapshot so that all three values land in the same slot even
    // if the strip is reset meanwhile

    int  slot = iteration;

    l0_activations = l0_activations.randomizeUniform ( 0.0, 1.0 );

    l0_activations.data [ 0 ] [ 0 ] = 1.0;  // bias

    samples [ slot ] [ 0 ] = l0_activations.data [ 1 ] [ 0 ];

    samples [ slot ] [ 1 ] = l0_activations.data [ 2 ] [ 0 ];

    l1_weighted_sums = Matrix.multiply (
      l1_weights.transpose ( ), l0_activations );

    l1_activations = l1_weighted_sums.sigmoid ( );

    l2_inputs.data [ 0 ] [ 0 ] = 1.0;  // bias

    l2_inputs.data [ 1 ] [ 0 ] = l1_activations.data [ 0 ] [ 0 ];

    l2_inputs.data [ 2 ] [ 0 ] = l1_activations.data [ 1 ] [ 0 ];

    l2_weighted_sums = Matrix.multiply (
      l2_weights.transpose ( ), l2_inputs );

    l2_activations = l2_weighted_sums.sigmoid ( );

    samples [ slot ] [ 2 ] = l2_activations.data [ 0 ] [ 0 ];
  }

  /*********************************************************************
  * Computes the output error for the current sample and from it the
  * local gradients, weight deltas and momentum terms of both layers.
  * The deltas are not applied here; see applyWeightDeltas ( ).
  *********************************************************************/
  private void  backpropagate ( )
  //////////////////////////////////////////////////////////////////////
  {
    output_desired = target_function ( l0_activations );

    output_error = output_desired - l2_activations.data [ 0 ] [ 0 ];

    // output layer

    l2_local_gradient = l2_weighted_sums.sigmoidDerivative ( );

    l2_local_gradient = l2_local_gradient.multiply ( output_error );

    // Original author's note:  "I'm not sure about the transpose below
    // here."  With a single output unit the gradient is declared 1 x 1.

    l2_weights_delta = Matrix.multiply (
      l2_inputs, l2_local_gradient.transpose ( ) );

    l2_weights_delta = l2_weights_delta.multiply ( learning_rate );

    l2_weights_delta = l2_weights_delta.add ( l2_weights_momentum );

    l2_weights_momentum
      = l2_weights_delta.multiply ( momentum_constant );

    // hidden layer

    // Original author's note:  "needs transpose or sum below?"
    // The submatrix skips row 0, the bias weight.

    sum_weighted_deltas = l2_weights.submatrix ( 1, 2, 0, 0 );

    sum_weighted_deltas = Matrix.multiply (
      sum_weighted_deltas, l2_local_gradient );

    // activation * ( 1 - activation ) * sum of weighted deltas

    l1_local_gradients = new Matrix (
      l1_activations.rows, l1_activations.cols, 1.0 );

    l1_local_gradients = l1_local_gradients.subtract ( l1_activations );

    l1_local_gradients = Matrix.multiplyPairwise (
      l1_activations, l1_local_gradients );

    l1_local_gradients = Matrix.multiplyPairwise (
      l1_local_gradients, sum_weighted_deltas );

    l1_weights_delta = Matrix.multiply (
      l0_activations, l1_local_gradients.transpose ( ) );

    l1_weights_delta = l1_weights_delta.multiply ( learning_rate );

    l1_weights_delta = l1_weights_delta.add ( l1_weights_momentum );

    l1_weights_momentum
      = l1_weights_delta.multiply ( momentum_constant );
  }

  /*********************************************************************
  * Records the squared error of the current sample and advances the
  * counters.  At the end of an epoch the RMS error over the epoch is
  * stored and the epoch number advanced.
  *
  * <p>
  * The epoch number wraps back to 1 as soon as it would reach
  * epochs_max, so the last slot of epoch_rms_error is never written.
  * The RMS value is stored before the new epoch number is published
  * so that the plot never shows an unfinished entry.
  * </p>
  *********************************************************************/
  private void  recordIteration ( )
  //////////////////////////////////////////////////////////////////////
  {
    synchronized ( stateLock )
    {
      squared_errors [ iteration ] = output_error * output_error;

      iteration++;

      total_samples++;

      if ( iteration % ITERATIONS_PER_EPOCH == 0 )
      {
        int  nextEpoch = epoch + 1;

        if ( nextEpoch >= epochs_max )
        {
          nextEpoch = 1;
        }

        iteration = 0;

        double  sum = 0.0;

        for ( int  i = 0; i < ITERATIONS_PER_EPOCH; i++ )
        {
          sum += squared_errors [ i ];
        }

        epoch_rms_error [ nextEpoch - 1 ]
          = Math.sqrt ( sum / ( double ) ITERATIONS_PER_EPOCH );

        epoch = nextEpoch;
      }
    }
  }

  /*********************************************************************
  * Parses the content of a text field as a double.
  *
  * @param  field
  *   The field whose text is parsed.
  * @param  current
  *   The value to keep if the text is not a valid number.
  * @return
  *   The parsed value, or current (after resetting the field text to
  *   it) if parsing failed.
  *********************************************************************/
  private static double  parseField (
    JTextField  field,
    double      current )
  //////////////////////////////////////////////////////////////////////
  {
    try
    {
      return Double.parseDouble ( field.getText ( ) );
    }
    catch ( NumberFormatException  ex )
    {
      field.setText ( "" + current );

      return current;
    }
  }

  /*********************************************************************
  * Draws one table cell.
  *
  * @param  column
  *   Table column; columns are XTAB pixels apart starting at x = 10.
  * @param  row
  *   Table row; rows are YTAB pixels apart.
  * @param  baseY
  *   The y coordinate of row 0.
  *********************************************************************/
  private void  drawCell (
    Graphics  g,
    String    text,
    int       column,
    int       row,
    int       baseY )
  //////////////////////////////////////////////////////////////////////
  {
    g.drawString ( text, 10 + column * XTAB, row * YTAB + baseY );
  }

  /*********************************************************************
  * Draws a title in row 0 of a table column followed by every element
  * of a matrix, one per row, taking the matrix column by column.
  *********************************************************************/
  private void  drawMatrixColumn (
    Graphics  g,
    String    title,
    Matrix    matrix,
    int       column,
    int       baseY )
  //////////////////////////////////////////////////////////////////////
  {
    drawCell ( g, title, column, 0, baseY );

    int  row = 1;

    for ( int  mc = 0; mc < matrix.cols; mc++ )
    {
      for ( int  mr = 0; mr < matrix.rows; mr++ )
      {
        drawCell ( g,
          decimalFormat.format ( matrix.data [ mr ] [ mc ] ),
          column, row, baseY );

        row++;
      }
    }
  }

  /*********************************************************************
  * Plots the RMS error of epochs 1 through the current epoch as red
  * points on a black strip, with a white line at half height.
  *
  * @param  bounds
  *   The area of the strip.
  * @param  epochs
  *   RMS error per epoch; entry i belongs to epoch i + 1.
  *********************************************************************/
  private void  plot_epochs (
    Rectangle   bounds,
    Graphics    g,
    double [ ]  epochs )
  //////////////////////////////////////////////////////////////////////
  {
    // the training thread may advance the epoch while we paint

    int  lastEpoch = epoch;

    g.setColor ( java.awt.Color.black );

    g.fillRect ( bounds.x, bounds.y, bounds.width, bounds.height );

    g.setColor ( java.awt.Color.white );

    g.drawRect ( bounds.x, bounds.y, bounds.width, bounds.height / 2 );

    // g.clipRect ( r.x, r.y, r.width, r.height );

    for ( int  index_epoch = 1; index_epoch <= lastEpoch; index_epoch++ )
    {
      PlotLib.xy ( java.awt.Color.red,
        ( double ) index_epoch, epochs [ index_epoch - 1 ],
        bounds, g, 1.0, ( double ) lastEpoch, 0.0, 1.0,
        OVAL_SIZE, true );
    }

    // g.clipRect ( 0, 0, SIZE.width, SIZE.height );

    g.setColor ( java.awt.Color.white );

    g.drawRect ( bounds.x, bounds.y, bounds.width, bounds.height );

    g.setColor ( this.getForeground ( ) );
  }

  /*********************************************************************
  * Plots the samples of the current epoch in the unit square, green
  * where the network output was at least 0.5 and red otherwise.
  *
  * @param  points
  *   Rows of { input 1, input 2, network output }.
  *********************************************************************/
  private void  plotGraph (
    Graphics        g,
    double [ ] [ ]  points )
  //////////////////////////////////////////////////////////////////////
  {
    g.setColor ( Color.black );

    g.fillRect ( rs.x, rs.y, rs.width, rs.height );

    for ( int  i = 0; i < ITERATIONS_PER_EPOCH; i++ )
    {
      double [ ]  point = points [ i ];

      Color  color = ( point [ 2 ] >= 0.5 ) ? Color.green : Color.red;

      PlotLib.xy ( color, point [ 0 ], point [ 1 ],
        rs, g, 0.0, 1.0, 0.0, 1.0, OVAL_SIZE );
    }

    g.setColor ( Color.white );

    g.drawRect ( rs.x, rs.y, rs.width, rs.height );

    g.setColor ( this.getForeground ( ) );
  }

  /*********************************************************************
  * Replaces the weights of both layers with uniform random values
  * requested in the range -1.0 to 1.0.
  *********************************************************************/
  private void  randomize_weights ( )
  //////////////////////////////////////////////////////////////////////
  {
    l1_weights = l1_weights.randomizeUniform ( -1.0, 1.0 );

    l2_weights = l2_weights.randomizeUniform ( -1.0, 1.0 );
  }

  /*********************************************************************
  * Evaluates the selected boolean function for an input vector.
  *
  * <p>
  * The two inputs are rounded to bits a and b.  function_selected is
  * read as a truth table:  the result is 1.0 if bit number 2 * b + a
  * of function_selected is set and 0.0 otherwise.
  * </p>
  *
  * @param  inputs
  *   Input vector; rows 1 and 2 hold the two inputs (row 0 is ignored).
  * @return
  *   The desired output, 0.0 or 1.0.
  *********************************************************************/
  private double  target_function ( Matrix  inputs )
  //////////////////////////////////////////////////////////////////////
  {
    long  a = Math.round ( inputs.data [ 1 ] [ 0 ] );

    long  b = Math.round ( inputs.data [ 2 ] [ 0 ] );

    long  bit_num = 2 * b + a;

    long  mask = 1 << bit_num;

    long  masked = function_selected & mask;

    return ( masked == mask ) ? 1.0 : 0.0;
  }

  //////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////
}