//
//  Rabbit.swift
//  CryptoSwift
//
//  Created by Dima Kalachov on 12/11/15.
//  Copyright © 2015 Marcin Krzyzanowski. All rights reserved.
//

/// Rabbit stream cipher (RFC 4503) with a 128-bit key and an optional 64-bit IV.
///
/// Every call to `encrypt`/`decrypt` re-runs key (and IV) setup, so each call
/// starts from the beginning of the keystream; a long message cannot be processed
/// in several chunks with one instance. The cipher state lives in mutable instance
/// properties without any synchronization, so one instance should not be used
/// from several threads at once.
final public class Rabbit {
    
    /// Size of IV in bytes
    public static let ivSize = 64 / 8
    
    /// Size of key in bytes
    public static let keySize = 128 / 8
    
    /// Size of one keystream block in bytes
    public static let blockSize = 128 / 8
    
    /// Key, most significant byte first
    private let key: [UInt8]
    
    /// IV (optional), most significant byte first
    private let iv: [UInt8]?
    
    /// State variables X0...X7
    private var x: [UInt32] = [0, 0, 0, 0, 0, 0, 0, 0]
    
    /// Counter variables C0...C7
    private var c: [UInt32] = [0, 0, 0, 0, 0, 0, 0, 0]
    
    /// Counter carry bit left over from the last counter update
    private var p7: UInt32 = 0
    
    /// Counter increment constants A0...A7; constant, so shared by all instances.
    private static let a: [UInt32] = [
        0x4D34D34D,
        0xD34D34D3,
        0x34D34D34,
        0x4D34D34D,
        0xD34D34D3,
        0x34D34D34,
        0x4D34D34D,
        0xD34D34D3,
    ]
    
    // MARK: - Initializers
    
    /// Creates a cipher without IV setup. Fails unless `key` is `keySize` bytes.
    convenience public init?(key: [UInt8]) {
        self.init(key: key, iv: nil)
    }
    
    /// Creates a cipher. Fails unless `key` is `keySize` bytes and `iv`, when
    /// given, is `ivSize` bytes.
    public init?(key: [UInt8], iv: [UInt8]?) {
        // Assign before validating: older Swift versions require a failable class
        // initializer to initialize every stored property before returning nil.
        self.key = key
        self.iv = iv
        
        // A missing IV is valid; treat it as having the expected length.
        let ivLength = iv?.count ?? Rabbit.ivSize
        guard key.count == Rabbit.keySize && ivLength == Rabbit.ivSize else {
            return nil
        }
    }
    
    // MARK: - Setup
    
    /// Resets all state from the key and then, if present, applies the IV.
    private func setup() {
        p7 = 0
        
        // Split the key into eight 16-bit subkeys K0...K7. The key array is read
        // most significant byte first, so K0 comes from the last two bytes.
        var k: [UInt32] = [0, 0, 0, 0, 0, 0, 0, 0]
        for j in 0..<8 {
            let low = UInt32(key[Rabbit.keySize - (2 * j + 1)])
            let high = UInt32(key[Rabbit.keySize - (2 * j + 2)])
            k[j] = (high << 8) | low
        }
        
        // Initialize state and counter variables from subkeys
        for j in 0..<8 {
            if j % 2 == 0 {
                x[j] = (k[(j + 1) % 8] << 16) | k[j]
                c[j] = (k[(j + 4) % 8] << 16) | k[(j + 5) % 8]
            } else {
                x[j] = (k[(j + 5) % 8] << 16) | k[(j + 4) % 8]
                c[j] = (k[j] << 16) | k[(j + 1) % 8]
            }
        }
        
        // Iterate system four times
        for _ in 0..<4 {
            nextState()
        }
        
        // Reinitialize counter variables
        for j in 0..<8 {
            c[j] ^= x[(j + 4) % 8]
        }
        
        if let iv = iv {
            setupIV(iv)
        }
    }
    
    /// Mixes the 64-bit IV into the counters and iterates the system four times.
    private func setupIV(_ iv: [UInt8]) {
        // 63...56 55...48 47...40 39...32 31...24 23...16 15...8 7...0 IV bits
        //    0       1       2       3       4       5       6     7   IV bytes in array
        let iv0 = Rabbit.word(iv[4], iv[5], iv[6], iv[7]) // IV[31..0]
        let iv1 = Rabbit.word(iv[0], iv[1], iv[4], iv[5]) // IV[63..48] || IV[31..16]
        let iv2 = Rabbit.word(iv[0], iv[1], iv[2], iv[3]) // IV[63..32]
        let iv3 = Rabbit.word(iv[2], iv[3], iv[6], iv[7]) // IV[47..32] || IV[15..0]
        
        // Modify the counter state as function of the IV: C_j ^= ivWords[j mod 4]
        let ivWords = [iv0, iv1, iv2, iv3]
        for j in 0..<8 {
            c[j] ^= ivWords[j % 4]
        }
        
        // Iterate system four times
        for _ in 0..<4 {
            nextState()
        }
    }
    
    // MARK: - Core iteration
    
    /// Advances the cipher by one iteration: counter update, then next-state function.
    private func nextState() {
        // Counter update: C_j = C_j + A_j + carry (mod 2^32). The carry ripples from
        // C0 up to C7, and the final carry is kept in p7 for the next iteration.
        var carry = p7
        for j in 0..<8 {
            let previous = c[j]
            c[j] = previous &+ Rabbit.a[j] &+ carry
            // A_j + carry is non-zero and below 2^32, so the sum wrapped exactly
            // when the result is smaller than the previous value.
            carry = previous > c[j] ? 1 : 0
        }
        p7 = carry
        
        // Each g_j depends only on the old state and the new counters; evaluate
        // every one once up front (the equations below use each g_j three times).
        var gv: [UInt32] = [0, 0, 0, 0, 0, 0, 0, 0]
        for j in 0..<8 {
            gv[j] = g(j)
        }
        
        // Iteration of the system
        x[0] = gv[0] &+ Rabbit.rotl(gv[7], 16) &+ Rabbit.rotl(gv[6], 16)
        x[1] = gv[1] &+ Rabbit.rotl(gv[0], 8) &+ gv[7]
        x[2] = gv[2] &+ Rabbit.rotl(gv[1], 16) &+ Rabbit.rotl(gv[0], 16)
        x[3] = gv[3] &+ Rabbit.rotl(gv[2], 8) &+ gv[1]
        x[4] = gv[4] &+ Rabbit.rotl(gv[3], 16) &+ Rabbit.rotl(gv[2], 16)
        x[5] = gv[5] &+ Rabbit.rotl(gv[4], 8) &+ gv[3]
        x[6] = gv[6] &+ Rabbit.rotl(gv[5], 16) &+ Rabbit.rotl(gv[4], 16)
        x[7] = gv[7] &+ Rabbit.rotl(gv[6], 8) &+ gv[5]
    }
    
    /// g function: XOR of the low and high 32-bit halves of (X_j + C_j)^2.
    private func g(_ j: Int) -> UInt32 {
        let sum = UInt64(x[j] &+ c[j])
        let square = sum * sum // (2^32 - 1)^2 fits in 64 bits, no overflow
        return UInt32((square ^ (square >> 32)) & 0xFFFFFFFF)
    }
    
    /// Iterates once and extracts one `blockSize`-byte keystream block,
    /// most significant byte (S[127..120]) first.
    private func nextOutput() -> [UInt8] {
        nextState()
        
        // 16-bit output words, each held in the low half of a UInt32
        let s: [UInt32] = [
            (x[6] >> 16) ^ (x[1] & 0xFFFF),   // S[127..112]
            (x[6] & 0xFFFF) ^ (x[3] >> 16),   // S[111..96]
            (x[4] >> 16) ^ (x[7] & 0xFFFF),   // S[95..80]
            (x[4] & 0xFFFF) ^ (x[1] >> 16),   // S[79..64]
            (x[2] >> 16) ^ (x[5] & 0xFFFF),   // S[63..48]
            (x[2] & 0xFFFF) ^ (x[7] >> 16),   // S[47..32]
            (x[0] >> 16) ^ (x[3] & 0xFFFF),   // S[31..16]
            (x[0] & 0xFFFF) ^ (x[5] >> 16),   // S[15..0]
        ]
        
        var output: [UInt8] = []
        output.reserveCapacity(Rabbit.blockSize)
        for word in s {
            output.append(UInt8((word >> 8) & 0xFF))
            output.append(UInt8(word & 0xFF))
        }
        return output
    }
    
    // MARK: - Helpers
    
    /// Packs four bytes into a word, first byte most significant.
    private static func word(_ b0: UInt8, _ b1: UInt8, _ b2: UInt8, _ b3: UInt8) -> UInt32 {
        return (UInt32(b0) << 24) | (UInt32(b1) << 16) | (UInt32(b2) << 8) | UInt32(b3)
    }
    
    /// 32-bit left rotation; only called with shifts of 8 and 16.
    private static func rotl(_ value: UInt32, _ shift: UInt32) -> UInt32 {
        return (value << shift) | (value >> (32 - shift))
    }
    
    /// XORs `bytes` with the keystream, restarting the keystream from setup.
    private func applyKeystream(_ bytes: [UInt8]) -> [UInt8] {
        setup()
        
        var result: [UInt8] = []
        result.reserveCapacity(bytes.count)
        var keystream: [UInt8] = []
        var offset = Rabbit.blockSize // forces a fresh block before the first byte
        for byte in bytes {
            if offset == Rabbit.blockSize {
                keystream = nextOutput()
                offset = 0
            }
            result.append(byte ^ keystream[offset])
            offset += 1
        }
        return result
    }
    
    // MARK: - Public
    
    /// Encrypts `bytes` by XOR with the keystream; the output has the same length.
    public func encrypt(bytes: [UInt8]) -> [UInt8] {
        return applyKeystream(bytes)
    }
    
    /// Decrypts `bytes`; identical to `encrypt` because Rabbit is a stream cipher.
    public func decrypt(bytes: [UInt8]) -> [UInt8] {
        return applyKeystream(bytes)
    }
}


// MARK: - Cipher

extension Rabbit: Cipher {
    
    /// `Cipher` requirement; forwards to `encrypt`.
    public func cipherEncrypt(bytes: [UInt8]) -> [UInt8] {
        return encrypt(bytes)
    }
    
    /// `Cipher` requirement; forwards to `decrypt`.
    public func cipherDecrypt(bytes: [UInt8]) -> [UInt8] {
        return decrypt(bytes)
    }
}
