Reading Pix2pix: train
https://github.com/phillipi/pix2pix/blob/master/train.lua
-- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua
-- code derived from https://github.com/soumith/dcgan.torch

>|lua|
require 'torch'
require 'nn'
require 'optim'
util = paths.dofile('util/util.lua')
require 'image'
require 'models'
||<
>|lua|
opt = {
   DATA_ROOT = '',           -- path to images (should have subfolders 'train', 'val', etc)
   batchSize = 1,            -- # images in batch
   loadSize = 286,           -- scale images to this size
   fineSize = 256,           -- then crop to this size
   ngf = 64,                 -- # of gen filters in first conv layer
   ndf = 64,                 -- # of discrim filters in first conv layer
   input_nc = 3,             -- # of input image channels
   output_nc = 3,            -- # of output image channels
   niter = 200,              -- # of iter at starting learning rate
   lr = 0.0002,              -- initial learning rate for adam
   beta1 = 0.5,              -- momentum term of adam
   ntrain = math.huge,       -- # of examples per epoch. math.huge for full dataset
   flip = 1,                 -- if flip the images for data augmentation
   display = 1,              -- display samples while training. 0 = false
   display_id = 10,          -- display window id.
   display_plot = 'errL1',   -- which loss values to plot over time.
                             -- Accepted values include a comma separated list of: errL1, errG, and errD
   gpu = 1,                  -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = '',                -- name of the experiment, should generally be passed on the command line
   which_direction = 'AtoB', -- AtoB or BtoA
   phase = 'train',          -- train, val, test, etc
   preprocess = 'regular',   -- for special purpose preprocessing, e.g., for colorization,
                             -- change this (selects preprocessing functions in util.lua)
   nThreads = 2,             -- # threads for loading data
   save_epoch_freq = 50,     -- save a model every save_epoch_freq epochs
                             -- (does not overwrite previously saved models)
   save_latest_freq = 5000,  -- save the latest model every latest_freq sgd iterations
                             -- (overwrites the previous latest model)
   print_freq = 50,          -- print the debug information every print_freq iterations
   display_freq = 100,       -- display the current results every display_freq iterations
   save_display_freq = 5000, -- save the current display of results every save_display_freq iterations
   continue_train = 0,       -- if continue training, load the latest model: 1: true, 0: false
   serial_batches = 0,       -- if 1, takes images in order to make batches, otherwise takes them randomly
   serial_batch_iter = 1,    -- iter into serial image list
   checkpoints_dir = './checkpoints', -- models are saved here
   cudnn = 1,                -- set to 0 to not use cudnn
   condition_GAN = 1,        -- set to 0 to use unconditional discriminator
   use_GAN = 1,              -- set to 0 to turn off GAN term
   use_L1 = 1,               -- set to 0 to turn off L1 term
   which_model_netD = 'basic', -- selects model to use for netD
   which_model_netG = 'unet',  -- selects model to use for netG
   n_layers_D = 0,           -- only used if which_model_netD=='n_layers'
   lambda = 100,             -- weight on L1 term in objective
}
||<
>|lua|
-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)

local input_nc = opt.input_nc
local output_nc = opt.output_nc

-- translation direction
local idx_A = nil
local idx_B = nil

if opt.which_direction=='AtoB' then   -- <== taken: 'AtoB' is the default
    idx_A = {1, input_nc}
    idx_B = {input_nc+1, input_nc+output_nc}
elseif opt.which_direction=='BtoA' then
    idx_A = {input_nc+1, input_nc+output_nc}
    idx_B = {1, input_nc}
else
    error(string.format('bad direction %s',opt.which_direction))
end

if opt.display == 0 then opt.display = false end

opt.manualSeed = torch.random(1, 10000) -- fix seed
print("Random Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setdefaulttensortype('torch.FloatTensor')

-- create data loader
local data_loader = paths.dofile('data/data.lua')
print('#threads...' .. opt.nThreads)
local data = data_loader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())
||<
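The parser deserves a pause: for every key in `opt` it looks up an environment variable of the same name, trying `tonumber` first so numeric settings arrive as numbers rather than strings. A minimal standalone sketch of the same pattern (the `defaults` table here is made up for illustration):

>|lua|
-- Standalone sketch of the env-var override pattern (hypothetical defaults).
local defaults = { batchSize = 1, name = '', lr = 0.0002 }
for k, v in pairs(defaults) do
   -- tonumber() first, so `batchSize=4` arrives as the number 4, not '4';
   -- non-numeric strings fall through to os.getenv(k); unset vars keep v.
   defaults[k] = tonumber(os.getenv(k)) or os.getenv(k) or defaults[k]
end
-- running e.g. `batchSize=4 name=expt1 lua sketch.lua` prints: 4  expt1  0.0002
print(defaults.batchSize, defaults.name, defaults.lr)
||<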
>|lua|
----------------------------------------------------------------------------
local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m.bias:fill(0)
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end
||<
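This is the DCGAN-style initialization: convolution weights drawn from N(0, 0.02) and batch-norm gains from N(1, 0.02). It is applied via `net:apply`, which visits every module in a container recursively, so one function covers the whole network. A toy demonstration with an arbitrary two-layer net:

>|lua|
-- Toy net to show what :apply(weights_init) touches (layers chosen arbitrarily).
local demo = nn.Sequential()
demo:add(nn.SpatialConvolution(3, 8, 4, 4, 2, 2, 1, 1))  -- torch.type finds 'Convolution'
demo:add(nn.SpatialBatchNormalization(8))                -- torch.type finds 'BatchNormalization'
demo:apply(weights_init)
print(demo:get(1).weight:std())  -- should come out close to 0.02
||<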
>|lua|
local ndf = opt.ndf
local ngf = opt.ngf
local real_label = 1
local fake_label = 0

function defineG(input_nc, output_nc, ngf)
    local netG = nil
    if     opt.which_model_netG == "encoder_decoder" then netG = defineG_encoder_decoder(input_nc, output_nc, ngf)
    elseif opt.which_model_netG == "unet"            then netG = defineG_unet(input_nc, output_nc, ngf)  -- <== taken: 'unet' is the default
    elseif opt.which_model_netG == "unet_128"        then netG = defineG_unet_128(input_nc, output_nc, ngf)
    else error("unsupported netG model")
    end

    netG:apply(weights_init)

    return netG
end
||<
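With the default `which_model_netG = 'unet'`, `defineG` dispatches to `defineG_unet` from models.lua (pulled in by `require 'models'` at the top). A hedged smoke test, assuming a CPU run with the default options, to confirm that the U-Net keeps the spatial size:

>|lua|
-- Hedged smoke test: the U-Net generator preserves resolution, so a
-- 256x256 input should yield a 256x256 translation.
local netG_test = defineG(3, 3, 64)     -- input_nc, output_nc, ngf
local x = torch.randn(1, 3, 256, 256)   -- batchSize x input_nc x fineSize x fineSize
print(netG_test:forward(x):size())      -- expect 1x3x256x256
||<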
>|lua|
function defineD(input_nc, output_nc, ndf)
    local netD = nil
    if opt.condition_GAN==1 then
        input_nc_tmp = input_nc
    else
        input_nc_tmp = 0 -- only penalizes structure in output channels
    end

    if     opt.which_model_netD == "basic"    then netD = defineD_basic(input_nc_tmp, output_nc, ndf)  -- <== taken: 'basic' is the default
    elseif opt.which_model_netD == "n_layers" then netD = defineD_n_layers(input_nc_tmp, output_nc, ndf, opt.n_layers_D)
    else error("unsupported netD model")
    end

    netD:apply(weights_init)

    return netD
end
||<
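`condition_GAN` decides whether the discriminator sees the input image at all: with the default of 1 it takes the A/B pair, so `input_nc_tmp + output_nc = 6` channels; with 0 it judges the output image alone. Note also that the 'basic' model is a PatchGAN, so its output is a grid of per-patch real/fake probabilities rather than a single scalar. A hedged shape check:

>|lua|
-- With condition_GAN = 1 the discriminator input is the A/B pair stacked
-- along dim 2 (channels): 3 + 3 = 6 channels under the default settings.
local netD_test = defineD(3, 3, 64)
local pair = torch.randn(1, 6, 256, 256)  -- stand-in for a concatenated A/B pair
print(netD_test:forward(pair):size())     -- a patch grid, e.g. 1x1x30x30, not a scalar
||<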
>|lua|
-- load saved models and finetune
if opt.continue_train == 1 then
   print('loading previously trained netG...')
   netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_G.t7'), opt)
   print('loading previously trained netD...')
   netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_D.t7'), opt)
else
   print('define model netG...')
   netG = defineG(input_nc, output_nc, ngf)
   print('define model netD...')
   netD = defineD(input_nc, output_nc, ndf)
end

print(netG)
print(netD)

local criterion = nn.BCECriterion()
local criterionAE = nn.AbsCriterion()
||<
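`criterion` (binary cross-entropy) carries the GAN term and `criterionAE` (absolute difference) carries the L1 term; further down in train.lua the generator objective combines them as GAN + lambda * L1. A self-contained sketch of that combination, with random tensors standing in for images and for netD's patch output:

>|lua|
-- Random stand-ins: fakeB/realB for images, d_out for netD's patch grid.
local fakeB = torch.rand(1, 3, 4, 4)
local realB = torch.rand(1, 3, 4, 4)
local d_out = torch.rand(1, 1, 2, 2)
local label = torch.Tensor(d_out:size()):fill(1)       -- real_label: G wants D fooled
local errG  = nn.BCECriterion():forward(d_out, label)  -- GAN term
local errL1 = nn.AbsCriterion():forward(fakeB, realB)  -- L1 term
print(errG + 100 * errL1)  -- opt.lambda = 100 weights the L1 term
||<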
>|lua|
---------------------------------------------------------------------------
optimStateG = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
optimStateD = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
||<
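These tables serve as both configuration and state for `optim.adam`: it reads `learningRate` and `beta1` from them and caches its moment estimates back into the same table between iterations. The call shape used by the training loop, sketched here with a toy quadratic objective standing in for the fDx/fGx closures:

>|lua|
require 'optim'
-- Fresh copy of the same config; optim.adam adds its own state fields to it.
local state = { learningRate = 0.0002, beta1 = 0.5 }
local x = torch.randn(5)
local feval = function(p)
   return p:dot(p), p * 2   -- toy loss ||p||^2 and its gradient
end
for i = 1, 200 do
   optim.adam(feval, x, state)  -- same call shape as the fDx/fGx closures later
end
print(x:norm())  -- shrinks a little each step at lr = 2e-4
||<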
>|lua|
----------------------------------------------------------------------------
local real_A  = torch.Tensor(opt.batchSize, input_nc, opt.fineSize, opt.fineSize)
local real_B  = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize)
local fake_B  = torch.Tensor(opt.batchSize, output_nc, opt.fineSize, opt.fineSize)
local real_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize)
local fake_AB = torch.Tensor(opt.batchSize, output_nc + input_nc*opt.condition_GAN, opt.fineSize, opt.fineSize)
local errD, errG, errL1 = 0, 0, 0
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()
----------------------------------------------------------------------------
||<
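The `real_AB`/`fake_AB` buffers make the conditioning concrete: `output_nc + input_nc*opt.condition_GAN` is 6 channels by default, i.e. the input image and the (real or generated) output stacked along the channel axis. The concatenation performed in the training loop amounts to:

>|lua|
-- Channel-wise concatenation of an A/B pair (dim 2 is channels in BCHW).
local A  = torch.randn(1, 3, 256, 256)
local B  = torch.randn(1, 3, 256, 256)
local AB = torch.cat(A, B, 2)
print(AB:size())  -- 1x6x256x256, matching real_AB/fake_AB above
||<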