-
Notifications
You must be signed in to change notification settings - Fork 11
/
example_1.m
27 lines (27 loc) · 1.54 KB
/
example_1.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
% Example 1: train a small convolutional GAN on MNIST.
% Needs mnist_uint8.mat (providing train_x) and gan_train on the MATLAB path.
clear;
clc;
% ----------- load mnist data
load('mnist_uint8', 'train_x');
image_size = 28;                 % side length of one square image, in pixels
% The [] lets reshape infer the number of images instead of hard-coding 60000.
% Dividing by 255 rescales the pixels (assuming uint8 data, as the file name
% suggests) into [0, 1], matching the generator's sigmoid output layer.
train_x = double(reshape(train_x, [], image_size, image_size))/255;
% train_x: [height, width, channel, images_index]
% (index 4 in the order refers to a trailing singleton dim -> the channel axis)
train_x = permute(train_x, [3, 2, 4, 1]);
batch_size = 64;
% NOTE(review): the image count (60000 in the stock file) is not a multiple
% of batch_size -- confirm how gan_train treats the remainder.
% ----------- shared hyper-parameters
latent_dim   = 100;              % length of the generator's input vector
kernel_size  = 5;                % kernel size used by every (transposed) convolution
shallow_maps = 16;               % feature maps at half resolution
deep_maps    = 32;               % feature maps at quarter resolution
half_size    = image_size/2;     % 14: spatial size between the two up/down-sampling steps
quarter_size = image_size/4;     % 7: spatial size at the fully connected interface
% Length of the flattened quarter-resolution volume; used by both networks so
% the generator's reshape and the discriminator's flatten always agree.
flat_features = quarter_size*quarter_size*deep_maps;
% ----------- model
% Generator: latent vector -> fully connected -> two stride-2 transposed
% convolutions ending in a 28x28x1 image with sigmoid activation.
generator.layers = {
    struct('type', 'input', 'output_shape', [latent_dim, batch_size])
    struct('type', 'fully_connect', 'output_shape', [flat_features, batch_size], 'activation', 'leaky_relu')
    struct('type', 'reshape', 'output_shape', [quarter_size, quarter_size, deep_maps, batch_size])
    struct('type', 'conv2d_transpose', 'output_shape', [half_size, half_size, shallow_maps, batch_size], 'kernel_size', kernel_size, 'stride', 2, 'padding', 'same', 'activation', 'leaky_relu')
    struct('type', 'conv2d_transpose', 'output_shape', [image_size, image_size, 1, batch_size], 'kernel_size', kernel_size, 'stride', 2, 'padding', 'same', 'activation', 'sigmoid')
    };
% Discriminator: two conv + sub_sampling stages (presumably each halving the
% spatial size, given the flatten shape below), then a single sigmoid output.
discriminator.layers = {
    struct('type', 'input', 'output_shape', [image_size, image_size, 1, batch_size])
    struct('type', 'conv2d', 'output_maps', shallow_maps, 'kernel_size', kernel_size, 'padding', 'same', 'activation', 'leaky_relu')
    struct('type', 'sub_sampling', 'scale', 2)
    struct('type', 'conv2d', 'output_maps', deep_maps, 'kernel_size', kernel_size, 'padding', 'same', 'activation', 'leaky_relu')
    struct('type', 'sub_sampling', 'scale', 2)
    struct('type', 'reshape', 'output_shape', [flat_features, batch_size])
    struct('type', 'fully_connect', 'output_shape', [1, batch_size], 'activation', 'sigmoid')
    };
% ----------- train
args = struct('batch_size', batch_size, 'epoch', 10, 'learning_rate', 0.001, 'optimizer', 'adam');
[generator, discriminator] = gan_train(generator, discriminator, train_x, args);