[ML] Swift TensorFlow (Part 4)

With 0.8 dropping, a few things from my previous posts changed, thankfully not much. And by trying to train bigger models, I ran into a huge RAM issue, so I'll share what I did about it in a few paragraphs.

Changes for 0.8

valueWithGradient is now a global module function, and you have to call it through the TensorFlow module, like this:

let (loss, grad) = TensorFlow.valueWithGradient(at: model) { (model: TextModel) -> Tensor<Float> in
    let logits = model(sampleFeatures)
    return softmaxCrossEntropy(logits: logits, labels: sampleLabels)
}
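
The rest of the training step doesn't change beyond that call: the gradient still feeds the optimizer directly. A minimal sketch, assuming an Adam optimizer set up as in the previous posts:

let optimizer = Adam(for: model)
// ... after computing (loss, grad) as above:
optimizer.update(&model, along: grad)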

Also, they revamped the serialization mechanics: you can now get serializable data through

try model.serializedParameters()
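
A quick sketch of saving a checkpoint with it, assuming the returned value is a Foundation Data (the path is a placeholder):

let checkpoint = try model.serializedParameters()
try checkpoint.write(to: URL(fileURLWithPath: "textmodel.params"))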

RAM issues

It so happens that someone told me to try character trigrams instead of word trigrams. I have no idea yet whether the results are better or worse, because the generated dataset is huge: 4*<number of chars>, and a somewhat simple text file gave way to a magnificent 96GB of RAM usage.

Of course, this means the program can't really run. It also meant I had to find an alternative, and the simplest one I knew of that I could implement quickly was storing all the trigrams in a big database and extracting random samples from it, rather than doing everything in memory. That meant going from 96GB of RAM usage down to 4GB.

The setup

I do Kitura stuff, and I ♥️ PostgreSQL, so I went for a simple ORM+Kuery setup.
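
For reference, the ConnectionPool everything below relies on comes out of SwiftKueryPostgreSQL; the setup looks something like this (host, credentials, and pool sizes are placeholders, obviously):

import SwiftKuery
import SwiftKueryPostgreSQL

// placeholder credentials: adjust to your local PostgreSQL setup
let pool = PostgreSQLConnection.createPool(
    host: "localhost",
    port: 5432,
    options: [.databaseName("trigrams"), .userName("ml"), .password("secret")],
    poolOptions: ConnectionPoolOptions(initialCapacity: 4, maxCapacity: 16, timeout: 10000))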

The table stores trigrams, and I went for generics for the stored structure:

struct StorableTrigram<FScalar, TScalar> : Codable where FScalar : TensorFlowScalar, FScalar : Codable, TScalar : TensorFlowScalar, TScalar : Codable {
    var random_id : Int64 // used to shuffle rows into partitions
    var t1 : FScalar      // the three features of the trigram
    var t2 : FScalar
    var t3 : FScalar
    var r : TScalar       // the label
}

extension StorableTrigram : Model {
    static var tableName: String {
        return ("StorableTrigram" + String(describing: FScalar.self) + String(describing: TScalar.self)).replacingOccurrences(of: " ", with: "_")
    }
}

The random_id will be used to shuffle the rows into multiple partitions later, and the tableName override avoids getting < and > in the table name.
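
For the Float/Int32 specialization used throughout this post, that gives:

StorableTrigram<Float, Int32>.tableName // "StorableTrigramFloatInt32"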

The partitioning

One of the key things needed to avoid saturating the RAM is partitioning the data. Since the rest of the training loop expects an array, I decided to go with a custom Collection that fits in a for loop but loads only the current partition:

struct RandomAccessPartition<FScalar, TScalar> : Collection where FScalar : TensorFlowScalar, FScalar : Codable, TScalar : TensorFlowScalar, TScalar : Codable {
    let numberOfPartitions: Int
    let db : ConnectionPool
    
    typealias Index = Int
    var startIndex: Int { return 0 }
    // endIndex is one past the last valid index, otherwise the last partition is never visited
    var endIndex: Int { return numberOfPartitions }
    
    func index(after i: Int) -> Int {
        return i+1
    }

    subscript(position: Int) -> (features: Tensor<FScalar>, labels: Tensor<TScalar>) {
        // each partition owns a contiguous slice of the random_id space
        let partitionSize = Int64.max / Int64(numberOfPartitions)
        let start_rid = partitionSize * Int64(position)
        let end_rid = partitionSize * Int64(position + 1)
        var rf : [[FScalar]] = []
        var rl : [TScalar] = []

        // the driver is asynchronous; the semaphore makes the subscript synchronous
        let lsem = DispatchSemaphore(value: 0)
        db.getConnection() { conn, err in
            if conn == nil {
                lsem.signal()
                return
            }
            
            conn!.execute("SELECT * FROM \"\(StorableTrigram<FScalar,TScalar>.tableName)\" WHERE random_id >= \(start_rid) AND random_id < \(end_rid)") { resultSet in
                resultSet.asRows { rows,error in
                    guard let rows = rows else {
                        lsem.signal()
                        return
                    }
                    for row in rows {
                        if let t1 = row["t1"] as? FScalar,
                           let t2 = row["t2"] as? FScalar,
                           let t3 = row["t3"] as? FScalar,
                           let r = row["r"] as? TScalar {
                            rf.append([t1,t2,t3])
                            rl.append(r)
                        }
                    }
                    lsem.signal()
                }
            }
        }

        lsem.wait()
        let featuresT = Tensor<FScalar>(shape: [rf.count, 3], scalars: rf.flatMap { $0 })
        let labelsT = Tensor<TScalar>(rl)
        return (featuresT, labelsT)
    }
}

Relying on random_id for the partitions is a bit iffy, but thankfully PostgreSQL can re-randomize those ids somewhat fast, so it works well enough for my use.

The TextBatch replacement

The three key features of that batch-holding struct were:

  • initialization
  • random sample (once)
  • random partitions (once every epoch)

So here's the relevant code, with breaks for explanations:

struct RandomAccessStringStorage {
    var db : ConnectionPool
    var tableCreated : Bool = false
    
    let original: [String]
    let vocabulary: [String]
    let indexHelper: [Int:Int]
    
    init(db database: ConnectionPool, original o: [String], terminator: String? = nil, fromScratch: Bool) {
        db = database
        Database.default = Database(database) // shady, but hey
        
        original = o
        let f : [[Float]]
        let l : [Int32]
        let v : [String]
        let h : [Int:Int]
        if let term = terminator {
            (f,l,v,h) = RandomAccessStringStorage.makeArrays(original, terminator: term)
        } else {
            (f,l,v,h) = RandomAccessStringStorage.makeArrays(original)
        }
        
        vocabulary = v
        indexHelper = h
        if fromScratch {
            deleteAll()
            for i in 0..<f.count {
                insertTrigram(t1: f[i][0], t2: f[i][1], t3: f[i][2], r: l[i])
            }
        } 
    }
    
    mutating func deleteAll() {
        let _ = try? StorableTrigram<Float,Int32>.dropTableSync()
        tableCreated = false
    }
    
    mutating func insertTrigram(t1: Float, t2: Float, t3: Float, r: Int32) {
        if !tableCreated {
            let _ = try? StorableTrigram<Float,Int32>.createTableSync()
            tableCreated = true
        }
        let trig = StorableTrigram(random_id: Int64.random(in: Int64(0)...Int64.max), t1: t1, t2: t2, t3: t3, r: r)
        let lsem = DispatchSemaphore(value: 0)
        trig.save { st, error in
            lsem.signal()
        }
        lsem.wait()
    }
// ...
}

The two makeArrays variants are copied and pasted from the in-memory TextBatch, and the only other thing the initialization relies on is insertion into the DB.
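
End to end, building the store from a text file looks something like this (a sketch: the corpus splitting follows whatever makeArrays expects from the previous posts, and fromScratch: false skips the lengthy insertion when the table is already populated):

let corpus = try String(contentsOfFile: "input.txt").components(separatedBy: .newlines)
var storage = RandomAccessStringStorage(db: pool, original: corpus, fromScratch: true)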

There are two ways of drawing random items: a one-off sample, and partitioning the data into random chunks:

func randomSample(of size: Int) -> (features: Tensor<Float>, labels: Tensor<Int32>) {
    var rf : [[Float]] = []
    var rl : [Int32] = []

    let lsem = DispatchSemaphore(value: 0)
    db.getConnection() { conn, err in
        if conn == nil {
            lsem.signal()
            return
        }
        
        conn!.execute("SELECT * FROM \"\(StorableTrigram<Float,Int32>.tableName)\" ORDER BY random() LIMIT \(size)") { resultSet in
            resultSet.asRows { rows,error in
                guard let rows = rows else {
                    lsem.signal()
                    return
                }
                for row in rows {
                    if let t1 = row["t1"] as? Float,
                       let t2 = row["t2"] as? Float,
                       let t3 = row["t3"] as? Float,
                       let r = row["r"] as? Int32 {
                        rf.append([t1,t2,t3])
                        rl.append(r)
                    }
                }
                lsem.signal()
            }
        }
    }
    
    lsem.wait()
    let featuresT = Tensor<Float>(shape: [rf.count, 3], scalars: rf.flatMap { $0 })
    let labelsT = Tensor<Int32>(rl)
    return (featuresT, labelsT)
}
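
That makes grabbing a one-off batch, say for a quick evaluation, a single call:

let (sampleFeatures, sampleLabels) = storage.randomSample(of: 1024)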

Random selection in Pg actually works pretty well, but the same draw can't be repeated, which is why we have to rely on the random_id to partition:

func randomSample(splits: Int) -> RandomAccessPartition<Float,Int32> {
    // reshuffle (will take a while)
    // update "StorableTrigramFloatInt32" SET random_id = cast(9223372036854775807 * random() as bigint);
    let lsem = DispatchSemaphore(value: 0)
    db.getConnection() { conn, err in
        if conn == nil {
            lsem.signal()
            return
        }
        
        conn!.execute("UPDATE \"\(StorableTrigram<Float,Int32>.tableName)\" SET random_id = cast(9223372036854775807 * random() as bigint)") { resultSet in
            lsem.signal()
        }
    }
    lsem.wait()
    return RandomAccessPartition<Float,Int32>(numberOfPartitions: splits, db: self.db)
}

The update will re-randomize the ids, paving the way for the RandomAccessPartition.
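
Put together, a training epoch over the whole dataset looks roughly like this (a sketch, reusing the model, optimizer, and storage from above; the partition count just bounds how much lands in RAM at once):

for epoch in 1...50 {
    // reshuffles the ids, then walks the partitions one by one
    for (features, labels) in storage.randomSample(splits: 100) {
        let (loss, grad) = TensorFlow.valueWithGradient(at: model) { (model: TextModel) -> Tensor<Float> in
            let logits = model(features)
            return softmaxCrossEntropy(logits: logits, labels: labels)
        }
        optimizer.update(&model, along: grad)
        print("epoch \(epoch), loss: \(loss)")
    }
}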

Of course, the performance tradeoff is rather big, especially in the initialization phase, but hey, more RAM to do other things while the model is training!