This is the last part of a 3-part series. In part 1, I tried to make sense of how it works and what we are trying to achieve, and in part 2, we set up the training loop.
Model Predictions
We have a trained model. Now what?
Remember, a model is a series of giant matrices that takes an input shaped like the ones you trained it on, and spits out a list of probabilities over the outputs you trained it on. So all you have to do is feed it a new input and see what it tells you:
let input : [Float] = [1.0, 179.0, 115.0] // three word indices, like the training windows
let unlabeled = Tensor<Float>(shape: [1, 3], scalars: input)
let predictions = model(unlabeled)
let logits = predictions[0]
let classIdx = logits.argmax().scalar! // we take only the best guess
print(classIdx)
17
Cool.
Cool, cool.
What?
Models deal with numbers. I am the one who assigned numbers to words to train the model on, so I need a translation layer. That's why I kept my contents structure around: I need it for its vocabulary map.
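In case parts 1 and 2 are a distant memory, the two pieces of contents that matter here look roughly like this (a sketch with only the relevant fields, their types inferred from how they are used below):

struct Contents {
    var vocabulary : [String]     // index -> word
    var indexHelper : [Int : Int] // word hash -> index
}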
The real code:
let w1 = "on"
let w2 = "flocks"
let w3 = "settlement"
var indices = [w1, w2, w3].map {
    Float(contents.indexHelper[$0.hash] ?? 0) // word -> index, 0 if unknown
}
var wordsToPredict = 50
var sentence = "\(w1) \(w2) \(w3)"
while wordsToPredict >= 0 {
    let unlabeled = Tensor<Float>(shape: [1, 3], scalars: indices)
    let predictions = model(unlabeled)
    for i in 0..<predictions.shape[0] {
        let logits = predictions[i]
        let classIdx = logits.argmax().scalar! // again, best guess only
        let word = contents.vocabulary[Int(classIdx)]
        sentence += " \(word)"
        // slide the window: append the prediction, drop the oldest index
        indices.append(Float(classIdx))
        indices.remove(at: 0)
        wordsToPredict -= 1
    }
}
print(sentence)
on flocks settlement or their enter the earth; their only hope in their arrows, which for want of it, with a thorn. and distinction of their nature, that in the same yoke are also chosen their chiefs or rulers, such as administer justice in their villages and by superstitious awe in times of old.
Notice how I remove the first input and append the one the model predicted, sliding the window so the loop keeps feeding itself.
Seeing that, it kind of makes you think about the suggestions game when you send text messages, eh? 😁
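And since the model hands back a full vector of scores rather than a single winner, nothing stops you from playing that game for real. A quick sketch, reusing logits and contents from above (the softmax call and the top-3 selection are my own embellishment, not part of the loop):

let probabilities = softmax(logits) // raw scores -> probabilities
let top3 = probabilities.scalars.enumerated()
    .sorted { $0.element > $1.element }
    .prefix(3)
for (index, probability) in top3 {
    print(contents.vocabulary[index], probability)
}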
Model Serialization
Training a model takes a long time. You don't want a multi-hour launch time on your program every time you want a prediction, and maybe you even want to keep updating the model every now and then. So we need a way to store it and load it.
Thankfully, tensors are just matrices, so it's easy to store an array of arrays of floats; we've been doing that forever. They are even Codable out of the box.
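Don't take my word for it; a quick sanity check (top-level try works in a script or playground, wrap it in do/catch otherwise):

import Foundation
import TensorFlow

let tensor = Tensor<Float>(shape: [2, 2], scalars: [1, 2, 3, 4])
let data = try JSONEncoder().encode(tensor)
let decoded = try JSONDecoder().decode(Tensor<Float>.self, from: data)
print(decoded) // same shape and values as tensor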
In my particular case, the model itself needs to remember a few things to be recreated:
- the number of inputs and hidden nodes, in order to recreate the Reshape and LSTMCell layers
- the internal probability matrices of both RNNs
- the weights and biases correction matrices
Because they are Codable, any regular Swift encoder will work, but I know some of you will want to see the actual matrices, so I use JSON. It is not the most time- or space-efficient, it does not come with a way to validate it, and JSON is an all-around awful storage format, but it makes a few things easy.
extension TextModel { // serialization
    struct TextModelParams : Codable {
        var inputs : Int
        var hidden : Int
        var rnn1w : Tensor<Float>
        var rnn1b : Tensor<Float>
        var rnn2w : Tensor<Float>
        var rnn2b : Tensor<Float>
        var weights : Tensor<Float>
        var biases : Tensor<Float>
    }

    func serializedParameters() throws -> Data {
        return try JSONEncoder().encode(TextModelParams(
            inputs: self.inputs,
            hidden: self.hidden,
            rnn1w: self.rnn1.cell.fusedWeight,
            rnn1b: self.rnn1.cell.fusedBias,
            rnn2w: self.rnn2.cell.fusedWeight,
            rnn2b: self.rnn2.cell.fusedBias,
            weights: self.weightsOut,
            biases: self.biasesOut))
    }

    struct TextModelSerializationError : Error { }

    init(_ serialized: Data) throws {
        guard let params = try? JSONDecoder().decode(TextModelParams.self, from: serialized) else { throw TextModelSerializationError() }
        inputs = params.inputs
        hidden = params.hidden
        reshape = Reshape<Float>([-1, inputs])
        var lstm1 = LSTMCell<Float>(inputSize: 1, hiddenSize: hidden)
        lstm1.fusedWeight = params.rnn1w
        lstm1.fusedBias = params.rnn1b
        var lstm2 = LSTMCell<Float>(inputSize: hidden, hiddenSize: hidden)
        lstm2.fusedWeight = params.rnn2w
        lstm2.fusedBias = params.rnn2b
        rnn1 = RNN(lstm1)
        rnn2 = RNN(lstm2)
        weightsOut = params.weights
        biasesOut = params.biases
        correction = weightsOut + biasesOut
    }
}
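Using it is then as boring as you'd hope: write the blob somewhere after training, read it back on the next launch. Something like this, with a file name of my choosing:

let url = URL(fileURLWithPath: "model.json")
try model.serializedParameters().write(to: url)
// ...much later, on another launch, instead of retraining:
let restored = try TextModel(Data(contentsOf: url))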
My resulting JSON file is around 70MB (25MB when bzipped), so not too bad.
When you serialize your model, remember to serialize the vocabulary mappings as well! Otherwise, you will lose the word <-> int translation layer.
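Something along these lines does it (VocabularyParams is a name I just made up, adapt it to your own contents structure):

struct VocabularyParams : Codable {
    var vocabulary : [String]     // index -> word
    var indexHelper : [Int : Int] // word hash -> index
}
let vocabularyData = try JSONEncoder().encode(
    VocabularyParams(vocabulary: contents.vocabulary,
                     indexHelper: contents.indexHelper))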
That's all, folks!
This was a quick and dirty intro to TensorFlow for some, Swift for others, and Swift for TensorFlow for most.
It definitely is a highly specialized and quite brittle piece of software, but it's a good conversation piece next time you hear that ML is going to take over the world.
Feel free to drop me comments or questions or corrections on Twitter!