# Copyright (c) 2018 - 2024 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from future import standard_library
standard_library.install_aliases()
from builtins import *
from builtins import str
import os, sys, struct
import nnir as ir

def _nnefShape(shape):
    """Render a shape sequence as the comma-separated list used inside NNEF ``[...]``."""
    return ', '.join([str(dim) for dim in shape])


def _nnefNode(node):
    """Return the NNEF statement for one NNIR node, or None if its type is not supported.

    Supported NNIR node types: conv, avg_pool, max_pool, relu, softmax, sum,
    batch_norm and gemm. The returned string is indented for the graph body and
    ends with a newline.
    """
    output = node.outputs[0]
    inputs = node.inputs
    if node.type == 'conv':
        pads = node.attr.get('pads')
        strides = node.attr.get('strides')
        dilations = node.attr.get('dilations')
        group = node.attr.get('group')
        # A third input, when present, is passed as NNEF conv's positional bias argument.
        bias = inputs[2] + ', ' if len(inputs) == 3 else ''
        # NOTE(review): pads are paired as (pads[0],pads[1]),(pads[2],pads[3]). If NNIR
        # stores pads ONNX-style ([h_begin, w_begin, h_end, w_end]) the NNEF pairs should
        # be (pads[0],pads[2]),(pads[1],pads[3]); this only matters when H and W padding
        # differ. TODO confirm against nnir.py.
        return ("    %s = conv(%s, %s, %sstride=[%d,%d], dilation=[%d,%d], "
                "padding=[(%d,%d),(%d,%d)], groups=%d, border = 'ignore');\n" % (
                    output, inputs[0], inputs[1], bias,
                    strides[0], strides[1], dilations[0], dilations[1],
                    pads[0], pads[1], pads[2], pads[3], group))
    if node.type == 'avg_pool' or node.type == 'max_pool':
        kernel_shape = node.attr.get('kernel_shape')
        pads = node.attr.get('pads')
        strides = node.attr.get('strides')
        dilations = node.attr.get('dilations')
        # NNEF pooling attributes cover four axes: the two leading axes get identity
        # values (size 1, stride 1, dilation 1, no padding). An empty NNIR attribute is
        # emitted as an empty NNEF list.
        padding = '(0,0),(0,0),(%d,%d),(%d,%d)' % (pads[0], pads[1], pads[2], pads[3]) if len(pads) != 0 else ''
        stride = '1,1,%d,%d' % (strides[0], strides[1]) if len(strides) != 0 else ''
        dilation = '1,1,%d,%d' % (dilations[0], dilations[1]) if len(dilations) != 0 else ''
        return ("    %s = %s(%s, size=[1,1,%d,%d], stride=[%s], dilation=[%s], "
                "padding=[%s], border = 'ignore');\n" % (
                    output, node.type, inputs[0], kernel_shape[0], kernel_shape[1],
                    stride, dilation, padding))
    if node.type == 'relu' or node.type == 'softmax':
        # NNIR and NNEF share the operation name for these.
        return "    %s = %s(%s);\n" % (output, node.type, inputs[0])
    if node.type == 'sum':
        return "    %s = add(%s, %s);\n" % (output, inputs[0], inputs[1])
    if node.type == 'batch_norm':
        # NNIR inputs are (x, scale, offset, mean, variance); NNEF expects
        # (input, mean, variance, offset, scale).
        return "    %s = batch_normalization(%s, %s, %s, %s, %s, epsilon = %e);\n" % (
            output, inputs[0], inputs[3], inputs[4], inputs[2], inputs[1],
            node.attr.get('epsilon'))
    if node.type == 'gemm':
        # NOTE(review): only inputs[0] and inputs[1] are emitted; any further gemm input
        # (e.g. a bias) is not translated. Confirm whether NNIR gemm nodes can carry one.
        return "    %s = matmul(%s, %s, transposeA = %s, transposeB = %s);\n" % (
            output, inputs[0], inputs[1],
            'true' if node.attr.get('transA') == 1 else 'false',
            'true' if node.attr.get('transB') == 1 else 'false')
    return None


def generateGraph(graph,outputFolder,label):
    """Write ``graph.nnef``, the NNEF textual graph for an NNIR graph, into *outputFolder*.

    Args:
        graph: NNIR graph. Reads ``inputs``, ``outputs`` and ``initializers`` (objects
            with ``name`` and ``shape``) and ``nodes`` (objects with ``type``,
            ``inputs``, ``outputs`` and an ``attr`` supporting ``get``).
        outputFolder: existing directory that receives ``graph.nnef``.
        label: prefix used in every variable's label (``'<label>/<tensor name>'``).

    Nodes whose type has no NNEF translation are skipped with a warning on stderr.
    Downstream statements may then reference tensors the graph never defines.
    """
    fileName = os.path.join(outputFolder, 'graph.nnef')
    print('creating ' + fileName + ' ...')
    lines = ["# This file is generated by nnir2nnef.py\nversion 1.0;\n\ngraph nnir (%s) -> (%s) {\n" % (
        ', '.join([tensor.name for tensor in graph.inputs]),
        ', '.join([tensor.name for tensor in graph.outputs]))]
    lines.append("\n    # external inputs\n")
    for tensor in graph.inputs:
        lines.append("    %s = external(shape = [%s]);\n" % (tensor.name, _nnefShape(tensor.shape)))
    lines.append("\n    # variables\n")
    for tensor in graph.initializers:
        lines.append("    %s = variable(shape = [%s], label = '%s/%s');\n" % (
            tensor.name, _nnefShape(tensor.shape), label, tensor.name))
    lines.append("\n    # nodes\n")
    for node in graph.nodes:
        statement = _nnefNode(node)
        if statement is None:
            print('WARNING: nnir2nnef: skipping unsupported node type %s (outputs: %s)'
                  % (node.type, ', '.join(node.outputs)), file=sys.stderr)
            continue
        lines.append(statement)
    lines.append("}\n")
    # Everything is formatted before the file is opened, so a formatting error cannot
    # leave a truncated graph.nnef behind. Text mode is required because every fragment
    # is str; writing str to a binary-mode file raises TypeError.
    with open(fileName, 'w') as f:
        f.writelines(lines)

def generateBinaries(graph,outputFolder,label):
    """Write one tensor-data file per initializer into ``<outputFolder>/<label>``.

    Each ``<tensor name>.dat`` file contains a header followed by the raw bytes from
    ``graph.binaries[tensor.name]``. The header is laid out as follows:

    * magic bytes 0x4E 0xEF;
    * version bytes 1, 0;
    * a 32-bit header length (16 + 4 * rank);
    * a 32-bit rank;
    * one 32-bit extent per dimension;
    * the bytes 0 and 32, followed by a 16-bit 0. These are presumably the data-type
      code, bits per item and quantization-string length of the NNEF 1.0 provisional
      tensor format (TODO confirm).

    Args:
        graph: NNIR graph; reads ``initializers`` (objects with ``name``/``shape``)
            and the ``binaries`` mapping from tensor name to bytes.
        outputFolder: existing directory in which the ``label`` folder is created.
        label: name of the sub-folder that receives the ``.dat`` files.
    """
    binaryFolder = os.path.join(outputFolder, label)
    print('creating variables in ' + binaryFolder + ' ...')
    if not os.path.isdir(binaryFolder):
        os.mkdir(binaryFolder)
    for tensor in graph.initializers:
        # Look the data up before opening the file so a missing entry does not leave an
        # empty .dat file behind.
        binary = graph.binaries[tensor.name]
        rank = len(tensor.shape)
        # '<' forces little-endian, unpadded fields as the NNEF data format requires,
        # independent of the host's native byte order.
        fields = [0x4E, 0xEF, 1, 0, 16 + 4 * rank, rank] + list(tensor.shape) + [0, 32, 0]
        header = struct.pack('<4B%dI2BH' % (rank + 2), *fields)
        fileName = os.path.join(binaryFolder, tensor.name + '.dat')
        with open(fileName, 'wb') as f:
            f.write(header)
            f.write(binary)

def generateNNEF(graph,outputFolder):
    """Write a complete NNEF model for *graph* into *outputFolder*.

    Creates *outputFolder* if it is missing (its parent must already exist). Then
    writes the textual graph, followed by the tensor-data files in the ``binary``
    sub-folder.
    """
    if not os.path.isdir(outputFolder):
        os.mkdir(outputFolder)
    # The same label names the data sub-folder and prefixes the variable labels in
    # the graph, so both writers must receive the same value.
    binaryLabel = 'binary'
    for writer in (generateGraph, generateBinaries):
        writer(graph, outputFolder, binaryLabel)

def main():
    """Command-line entry point: ``nnir2nnef.py <nnirInputFolder> <outputFolder>``.

    Loads the NNIR model from the input folder and writes the NNEF model into the
    output folder. If fewer than two arguments are given, prints the usage line and
    exits with status 1.
    """
    args = sys.argv[1:]
    if len(args) < 2:
        print('Usage: python nnir2nnef.py <nnirInputFolder> <outputFolder>')
        sys.exit(1)
    nnirFolder = args[0]
    nnefFolder = args[1]
    print('reading NNIR model from ' + nnirFolder + ' ...')
    graph = ir.IrGraph(True)
    graph.fromFile(nnirFolder)
    print('creating NNEF in ' + nnefFolder + ' ...')
    generateNNEF(graph, nnefFolder)

# Run the converter only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
