Binary Neural Network

This page contains a short tutorial on how to train a binary neural network. It is loosely adapted from the Flux model zoo's MLP example.

We will train a neural network with binary {-1,+1} weights and hidden-layer activations. The output layer will have real-valued neurons. Binary networks require less storage and memory, and when deployed on edge-computing devices they can be more computationally efficient than real-valued networks (less so on traditional CPUs and GPUs, which have highly optimized BLAS and CUDA kernels for Float32 operations).

To create a network with binary weights, we need to specify a forward mapping that applies a binary activation function to the weights when we create the layers. We will also specify a binary activation function to be applied to each layer's output. The tricky part about training binary neural networks is that a binary function such as sign(x) has zero derivative everywhere except at the origin, where it is undefined. A typical workaround is to use the sign function only during the forward pass and replace it with a differentiable surrogate such as tanh or the identity during the backward pass. This hack works surprisingly well, and a lot of research on binary neural networks explores variations of this idea. We will use the straight-through estimator (STE), described by Courbariaux et al. in the BinaryConnect paper. During the forward pass it acts as the sign function, but during the backward pass (i.e. during backprop) it is replaced by the identity function.
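Bender.jl ships a ready-made straight-through sign function (used below), so you do not need to implement this yourself. Purely for illustration, here is a minimal sketch of how such a rule could be written with the ChainRulesCore package; the name my_sign_STE and the rule body are illustrative assumptions, not Bender.jl's actual implementation.

using ChainRulesCore

# Forward pass: behave exactly like sign(x)
my_sign_STE(x::Real) = sign(x)

# Backward pass: pretend the function was the identity and pass the
# incoming gradient straight through, instead of returning zero.
function ChainRulesCore.rrule(::typeof(my_sign_STE), x::Real)
    y = sign(x)
    straight_through_pullback(ȳ) = (NoTangent(), ȳ)
    return y, straight_through_pullback
end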

Bender.jl provides a stochastic and a deterministic variant of the STE, stoc_sign_STE and sign_STE. We will use the deterministic version here. In a traditional fully connected layer the output is computed by first forming a preactivation $a = Wx + b$ and then applying an elementwise activation function, $z_i = f(a_i)$. When initializing a GenDense layer you can specify a forward mapping for the computation of $a$, which is how we make the layer use the straight-through estimator. We will use the following mapping (which is also exported by Bender.jl):

function linear_binary_weights(a, x)
    W, b = a.weight, a.bias
    # Binarize the weights in the forward pass; during the backward pass
    # sign_STE lets the gradient flow to the real-valued W unchanged.
    return sign_STE.(W)*x .+ b
end

Using this forward mapping, we define a three-layer fully connected network.

function build_model()
    m = Chain(GenDense(784=>300, sign_STE; forward=linear_binary_weights),
              GenDense(300=>100, sign_STE; forward=linear_binary_weights),
              GenDense(100=>10; forward=linear_binary_weights))
    return m
end
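As a quick sanity check (a sketch; the dummy batch below is an assumption, and the packages listed in the next section need to be loaded first), you can push a random batch through the model and confirm the output shape:

m = build_model()
x = rand(Float32, 784, 32)  # a dummy batch of 32 flattened 28×28 images
size(m(x))                  # expected: (10, 32), i.e. ten logits per image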

The following sections contain some boilerplate code. You can expand them to show the details.

Handling dependencies

using Pkg; Pkg.add(url="https://github.com/Rasmuskh/Bender.jl.git")
begin
	using Bender, Flux, MLDatasets
	using Flux: onehotbatch, onecold, logitcrossentropy, throttle
	using Flux.Data: DataLoader
	using Parameters: @with_kw
	using DataFrames
end

Utilities

@with_kw mutable struct Args
    η::Float64 = 0.0003     # learning rate
    batchsize::Int = 64     # batch size
    epochs::Int = 10        # number of epochs
    device::Function = gpu  # set as gpu, if gpu available
end
function getdata(args)
    ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

    # Loading Dataset	
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    # Reshape Data in order to flatten each image into a vector
    xtrain = Flux.flatten(xtrain)
    xtest = Flux.flatten(xtest)

    # One-hot-encode the labels
    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    # Batching
    train_data = DataLoader((xtrain, ytrain), batchsize=args.batchsize, shuffle=true, partial=false)
    test_data = DataLoader((xtest, ytest), batchsize=args.batchsize, partial=false)

    return train_data, test_data
end
function evaluate(data_loader, model)
    acc = 0
    l = 0
    numbatches = length(data_loader)
    for (x, y) in data_loader
        # Fraction of correctly classified examples in this batch
        acc += sum(onecold(cpu(model(x))) .== onecold(cpu(y))) / size(x, 2)
        l += logitcrossentropy(model(x), y)
    end
    # Average accuracy and loss over all batches
    return acc/numbatches, l/numbatches
end

Defining the training loop

function train(; kws...)
    # Initialize hyperparameters
    args = Args(; kws...)

    # Create arrays for recording training metrics
    acc_train = zeros(Float32, args.epochs)
    acc_test = zeros(Float32, args.epochs)
    loss_train = zeros(Float32, args.epochs)
    loss_test = zeros(Float32, args.epochs)

    # Load data
    train_data, test_data = getdata(args)

    # Construct model
    m = build_model()

    # Move data and model to the chosen device
    train_data = args.device.(train_data)
    test_data = args.device.(test_data)
    m = args.device(m)
    loss(x, y) = logitcrossentropy(m(x), y)

    # Training
    opt = ADAM(args.η)
    for epoch = 1:args.epochs
        Flux.train!(loss, params(m), train_data, opt)
        acc_train[epoch], loss_train[epoch] = evaluate(train_data, m)
        acc_test[epoch], loss_test[epoch] = evaluate(test_data, m)
        # Clamp the real-valued weights to lie in the range [-1, +1]
        for layer in m
            layer.weight .= clamp.(layer.weight, -1, 1)
        end
    end

    # Return training metrics as a DataFrame
    df = DataFrame([loss_train, loss_test, acc_train, acc_test],
                   [:loss_train, :loss_test, :acc_train, :acc_test])
    return df, cpu(m)
end

Training the model

Although we constrained the model to use {-1,+1} weights during inference, it still manages to solve the problem fairly well. The train function records the model's loss and accuracy on the training and test sets in a DataFrame.

df, m = train(epochs=10); round.(df, digits=3)
 Row  loss_train  loss_test  acc_train  acc_test
      Float32     Float32    Float32    Float32
   1  1.9         1.862      0.789      0.792
   2  1.56        1.542      0.824      0.83
   3  1.365       1.361      0.85       0.849
   4  1.362       1.382      0.846      0.852
   5  1.294       1.34       0.859      0.859
   6  1.234       1.224      0.866      0.87
   7  1.268       1.269      0.865      0.869
   8  1.306       1.395      0.857      0.855
   9  1.222       1.26       0.873      0.873
  10  1.22        1.263      0.867      0.867
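After training, the real-valued weights stored in the layers are only the latent parameters; the weights actually used in the forward pass are their signs. As a rough sketch (assuming the weight field used by the clamping step above), you could extract the binary weight matrices like this:

# Binarize the latent weights, mirroring what sign_STE does in the forward pass
Wbin = [sign.(layer.weight) for layer in m]
unique(Wbin[1])  # mostly -1.0f0 and 1.0f0 (weights that are exactly zero map to 0.0f0)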