のねのBlog

パソコンの問題や、ソフトウェアの開発で起きた問題など書いていきます。よろしくお願いします^^。

pix2pix を読む: train.lua

https://github.com/phillipi/pix2pix/blob/master/train.lua

-- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua 
--
-- code derived from https://github.com/soumith/dcgan.torch
--
>|lua|
require 'torch'
require 'nn'
require 'optim'
util = paths.dofile('util/util.lua')
require 'image'
require 'models'
-- Default training options. Each key can be overridden by an environment
-- variable of the same name (see the one-line argument parser below).
opt = {
   -- path to images (should have subfolders 'train', 'val', etc.)
   DATA_ROOT = '',
   -- # of images in a batch
   batchSize = 1,
   -- scale images to this size ...
   loadSize = 286,
   -- ... then crop to this size
   fineSize = 256,
   -- # of generator filters in the first conv layer
   ngf = 64,
   -- # of discriminator filters in the first conv layer
   ndf = 64,
   -- # of input image channels
   input_nc = 3,
   -- # of output image channels
   output_nc = 3,
   -- # of iterations at the starting learning rate
   niter = 200,
   -- initial learning rate for Adam
   lr = 0.0002,
   -- momentum term of Adam
   beta1 = 0.5,
   -- # of examples per epoch; math.huge means the full dataset
   ntrain = math.huge,
   -- whether to flip images for data augmentation
   flip = 1,
   -- display samples while training; 0 = false
   display = 1,
   -- display window id
   display_id = 10,
   -- which loss values to plot over time; accepted values are a
   -- comma-separated list of: errL1, errG, errD
   display_plot = 'errL1',
   -- gpu = 0 is CPU mode; gpu = X is GPU mode on GPU X
   gpu = 1,
   -- name of the experiment; should generally be given on the command line
   -- (e.g. name=expt1 th train.lua)
   name = '',
   -- AtoB or BtoA
   which_direction = 'AtoB',
   -- train, val, test, etc.
   phase = 'train',
   -- special-purpose preprocessing (e.g. change this for colorization);
   -- selects preprocessing functions in util.lua
   preprocess = 'regular',
   -- # of threads for loading data
   nThreads = 2,
   -- save a model every save_epoch_freq epochs
   -- (does not overwrite previously saved models)
   save_epoch_freq = 50,
   -- save the latest model every save_latest_freq SGD iterations
   -- (overwrites the previous latest model)
   save_latest_freq = 5000,
   -- print debug information every print_freq iterations
   print_freq = 50,
   -- display the current results every display_freq iterations
   display_freq = 100,
   -- save the current display of results every save_display_freq iterations
   save_display_freq = 5000,
   -- continue training from the latest saved model: 1 = true, 0 = false
   continue_train = 0,
   -- if 1, take images in order to make batches; otherwise take them randomly
   serial_batches = 0,
   -- iterator into the serial image list
   serial_batch_iter = 1,
   -- models are saved here
   checkpoints_dir = './checkpoints',
   -- set to 0 to not use cudnn
   cudnn = 1,
   -- set to 0 to use an unconditional discriminator
   condition_GAN = 1,
   -- set to 0 to turn off the GAN term
   use_GAN = 1,
   -- set to 0 to turn off the L1 term
   use_L1 = 1,
   -- selects the model to use for netD
   which_model_netD = 'basic',
   -- selects the model to use for netG
   which_model_netG = 'unet',
   -- only used if which_model_netD == 'n_layers'
   n_layers_D = 0,
   -- weight on the L1 term in the objective
   lambda = 100,
}
-- One-line argument parser: every key of opt may be overridden by an
-- environment variable with the same name. Values that tonumber accepts
-- become numbers; any other set value is kept as a string; unset variables
-- leave the default untouched.
for key, default in pairs(opt) do
   local env_value = os.getenv(key)
   opt[key] = tonumber(env_value) or env_value or default
end
print(opt)

local input_nc = opt.input_nc
local output_nc = opt.output_nc

-- Translation direction. idx_A / idx_B are inclusive {first, last} index
-- ranges (for AtoB their sizes are input_nc and output_nc).
-- NOTE(review): presumably these select the A and B channels out of the
-- combined image produced by the data loader -- confirm against data.lua.
local idx_A = nil
local idx_B = nil

if opt.which_direction == 'AtoB' then -- default
    idx_A = {1, input_nc}
    idx_B = {input_nc + 1, input_nc + output_nc}
elseif opt.which_direction == 'BtoA' then
    -- NOTE(review): here idx_A spans output_nc entries and idx_B spans
    -- input_nc entries, while real_A / real_B are allocated with input_nc /
    -- output_nc channels; this only lines up when input_nc == output_nc.
    idx_A = {input_nc + 1, input_nc + output_nc}
    idx_B = {1, input_nc}
else
    -- Reject unknown directions up front instead of running with nil indices.
    error(string.format('bad direction %s', opt.which_direction))
end

-- display = 0 means "no display". Turn it into a real false, because 0 is
-- truthy in Lua. Any other value is kept as-is.
opt.display = (opt.display ~= 0) and opt.display

-- A fresh random seed in [1, 10000] is drawn on every run (it is NOT fixed).
-- It is stored in opt, printed, and used to seed torch.
do
   local seed = torch.random(1, 10000)
   opt.manualSeed = seed
   print("Random Seed: " .. seed)
   torch.manualSeed(seed)
end
torch.setdefaulttensortype('torch.FloatTensor')

-- Build the data loader defined in data/data.lua with opt.nThreads threads.
local data_loader = paths.dofile('data/data.lua')
print('#threads...' .. opt.nThreads)
local data = data_loader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())
----------------------------------------------------------------------------
--- Initialize one module's parameters in place.
-- Meant to be passed to `net:apply(...)`, which visits every submodule.
-- Convolution layers: weight ~ N(0, 0.02), bias set to 0.
-- BatchNormalization layers: weight ~ N(1, 0.02), bias set to 0.
-- Any other module type is left untouched.
-- @param m an nn module
local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      -- A convolution built without a bias (e.g. via :noBias()) has
      -- m.bias == nil; guard it the same way the BN branch does.
      if m.bias then m.bias:fill(0) end
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end
-- Filter counts for the first conv layer of D and G (copied from opt), and
-- the label values for real / fake samples.
local ndf, ngf = opt.ndf, opt.ngf
local real_label, fake_label = 1, 0

--- Build the generator selected by opt.which_model_netG and initialize it.
-- Supported values: "encoder_decoder", "unet", "unet_128"; anything else
-- raises an error.
-- NOTE(review): the defineG_* constructors are presumably provided by
-- models.lua (loaded via require 'models') -- confirm.
-- @param input_nc number of input image channels
-- @param output_nc number of output image channels
-- @param ngf number of generator filters in the first conv layer
-- @return the generator, after weights_init has been applied to it
function defineG(input_nc, output_nc, ngf)
    local netG = nil
    local which = opt.which_model_netG
    if which == "encoder_decoder" then
        netG = defineG_encoder_decoder(input_nc, output_nc, ngf)
    elseif which == "unet" then
        netG = defineG_unet(input_nc, output_nc, ngf)
    elseif which == "unet_128" then
        netG = defineG_unet_128(input_nc, output_nc, ngf)
    else
        -- Fail with a clear message instead of crashing on netG:apply below.
        error("unsupported netG model: " .. tostring(which))
    end

    netG:apply(weights_init)

    return netG
end
--- Build the discriminator selected by opt.which_model_netD and initialize it.
-- Supported values: "basic", "n_layers" (uses opt.n_layers_D); anything
-- else raises an error.
-- When opt.condition_GAN == 1, input_nc is passed to the constructor
-- (conditional discriminator). Otherwise 0 is passed, so D only penalizes
-- structure in the output channels.
-- NOTE(review): the defineD_* constructors are presumably provided by
-- models.lua (loaded via require 'models') -- confirm.
-- @param input_nc number of input image channels
-- @param output_nc number of output image channels
-- @param ndf number of discriminator filters in the first conv layer
-- @return the discriminator, after weights_init has been applied to it
function defineD(input_nc, output_nc, ndf)
    -- Local on purpose: the previous version leaked this as a global.
    local input_nc_tmp
    if opt.condition_GAN == 1 then
        input_nc_tmp = input_nc
    else
        input_nc_tmp = 0 -- only penalizes structure in output channels
    end

    local netD = nil
    local which = opt.which_model_netD
    if which == "basic" then
        netD = defineD_basic(input_nc_tmp, output_nc, ndf)
    elseif which == "n_layers" then
        netD = defineD_n_layers(input_nc_tmp, output_nc, ndf, opt.n_layers_D)
    else
        -- Fail with a clear message instead of crashing on netD:apply below.
        error("unsupported netD model: " .. tostring(which))
    end

    netD:apply(weights_init)

    return netD
end
-- Obtain netG / netD (both globals). With continue_train == 1, load the
-- 'latest' checkpoints of this experiment and fine-tune from them;
-- otherwise build fresh networks.
if opt.continue_train ~= 1 then
  print('define model netG...')
  netG = defineG(input_nc, output_nc, ngf)

  print('define model netD...')
  netD = defineD(input_nc, output_nc, ndf)
else
  -- <checkpoints_dir>/<name>/<file_name>
  local function latest_checkpoint(file_name)
     return paths.concat(opt.checkpoints_dir, opt.name, file_name)
  end

  print('loading previously trained netG...')
  netG = util.load(latest_checkpoint('latest_net_G.t7'), opt)

  print('loading previously trained netD...')
  netD = util.load(latest_checkpoint('latest_net_D.t7'), opt)
end

-- Show both network architectures.
print(netG)
print(netD)

-- criterion: binary cross-entropy (presumably the GAN loss on D's output --
-- confirm). criterionAE: absolute-error criterion, i.e. the L1 term.
local criterion = nn.BCECriterion()
local criterionAE = nn.AbsCriterion()
---------------------------------------------------------------------------
-- Adam settings for G and D. They are built as two separate tables with the
-- same learning rate and beta1.
-- NOTE(review): presumably each is later handed to the optimizer for its own
-- network, so they must not be shared -- confirm.
do
   local function adam_settings()
      return { learningRate = opt.lr, beta1 = opt.beta1 }
   end
   optimStateG = adam_settings()
   optimStateD = adam_settings()
end
----------------------------------------------------------------------------
-- Pre-allocated batch buffers, each batchSize x C x fineSize x fineSize.
-- real_AB / fake_AB have output_nc channels plus input_nc more when
-- opt.condition_GAN == 1.
local real_A, real_B, fake_B, real_AB, fake_AB
do
   local function new_batch(channels)
      return torch.Tensor(opt.batchSize, channels, opt.fineSize, opt.fineSize)
   end
   local pair_channels = output_nc + input_nc*opt.condition_GAN

   real_A = new_batch(input_nc)
   real_B = new_batch(output_nc)
   fake_B = new_batch(output_nc)
   real_AB = new_batch(pair_channels)
   fake_AB = new_batch(pair_channels)
end

-- Loss values (start at 0) and timers.
local errD, errG, errL1 = 0, 0, 0
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()
----------------------------------------------------------------------------