| import tensorflow as tf |
| import tensorflow_datasets as tfds |
| import os |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from tensorflow.keras import regularizers |
|
|
| assert 'COLAB_TPU_ADDR' in os.environ, 'Missin TPU?' |
| if('COLAB_TPU_ADDR') in os.environ: |
| TF_MASTER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR']) |
| else: |
| TF_MASTER = '' |
| tpu_address = TF_MASTER |
|
|
| resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address) |
| tf.config.experimental_connect_to_cluster(resolver) |
| tf.tpu.experimental.initialize_tpu_system(resolver) |
|
|
|
|
| strategy = tf.distribute.TPUStrategy(resolver) |
|
|
|
|
| def create_model(): |
| return tf.keras.Sequential([ |
| tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), |
| tf.keras.layers.BatchNormalization(), |
| tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), |
| tf.keras.layers.BatchNormalization(), |
| tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), |
| tf.keras.layers.Dropout(0.25), |
|
|
| tf.keras.layers.Conv2D(128, (3, 3), activation='relu'), |
| tf.keras.layers.BatchNormalization(), |
| tf.keras.layers.Conv2D(256, (3, 3), activation='relu', kernel_regularizer=regularizers.l2(0.001)), |
| tf.keras.layers.BatchNormalization(), |
| tf.keras.layers.MaxPooling2D(pool_size=(2, 2)), |
| tf.keras.layers.Dropout(0.25), |
|
|
| tf.keras.layers.Flatten(), |
| tf.keras.layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.001)), |
| tf.keras.layers.BatchNormalization(), |
| tf.keras.layers.Dropout(0.5), |
| tf.keras.layers.Dense(256, activation='relu', kernel_regularizer=regularizers.l2(0.001)), |
| tf.keras.layers.BatchNormalization(), |
| tf.keras.layers.Dropout(0.5), |
| tf.keras.layers.Dense(10, activation='softmax') |
| ]) |
|
|
|
|
| def get_dataset(batch_size, is_training=True): |
| split = 'train' if is_training else 'test' |
| dataset, info = tfds.load(name='mnist', split=split, with_info= True, as_supervised=True, try_gcs=True) |
| def scale(image, label): |
| image = tf.cast(image, tf.float32) |
| image /= 255.0 |
| return image, label |
| dataset = dataset.map(scale) |
| if is_training: |
| dataset = dataset.shuffle(10000) |
| dataset = dataset.repeat() |
| dataset = dataset.batch(batch_size) |
| return dataset |
|
|
|
|
|
|
| with strategy.scope(): |
| model = create_model() |
| model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy']) |
| model.summary() |
|
|
| |
|
|
|
|
| batch_size = 512 |
| train_dataset = get_dataset(batch_size, True) |
| validation_dataset = get_dataset(batch_size, False) |
| with strategy.scope(): |
| model = create_model() |
| model.compile(optimizer='adam', steps_per_execution=50, loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy']) |
| epochs = 80 |
| steps_per_epoch = 60000 // batch_size |
| validation_steps = 10000 // batch_size |
| history = model.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch, validation_data=validation_dataset, validation_steps=validation_steps) |
| |
|
|
| acc = history.history['sparse_categorical_accuracy'] |
| val_acc = history.history['val_sparse_categorical_accuracy'] |
| loss = history.history['loss'] |
| val_loss = history.history['val_loss'] |
| epochs_range = range(epochs) |
|
|
|
|
| plt.figure(figsize=(15, 15)) |
| plt.subplot(2, 2, 1) |
| plt.plot(epochs_range, acc, label='Training Accuracy') |
| plt.plot(epochs_range, val_acc, label='Validation Accuracy') |
| plt.legend(loc='lower right') |
| plt.title('Training and Validation Accuracy') |
|
|
| plt.subplot(2, 2, 2) |
| plt.plot(epochs_range, loss, label='Training Loss') |
| plt.plot(epochs_range, val_loss, label='Validation Loss') |
| plt.legend(loc='upper right') |
| plt.title('Training and Validation Loss') |
| plt.show() |
|
|
|
|
| final_daset = validation_dataset.take(10) |
| test_images, test_labels = next(iter(final_daset.take(10))) |
| class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] |
|
|
| |
| predictions = model.predict(test_images) |
|
|
| fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(15, 6), |
| subplot_kw={'xticks': [], 'yticks': []}) |
| for i, ax in enumerate(axes.flat): |
| |
| ax.imshow(test_images[i]) |
| |
| true_label = class_names[test_labels[i]] |
| pred_label = class_names[np.argmax(predictions[i])] |
| if true_label == pred_label: |
| ax.set_title("Это: {}, ИИ: {}".format(true_label, pred_label), color='green') |
| else: |
| ax.set_title("Это: {}, ИИ: {}".format(true_label, pred_label), color='red') |
|
|
| plt.tight_layout() |
| plt.show() |