/*
 * Copyright (c) 2017 Rob Clark <robdclark@gmail.com>
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <assert.h>
#include <err.h>
#include <fcntl.h>
#include <getopt.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "util/macros.h"
#include "afuc.h"
#include "asm.h"
#include "parser.h"
#include "util.h"

int gpuver;

/* bit lame to hard-code max but fw sizes are small */
static struct asm_instruction instructions[0x2000];
static unsigned num_instructions;

static struct asm_label labels[0x512];
static unsigned num_labels;

struct asm_instruction *
next_instr(int tok)
{
   struct asm_instruction *ai = &instructions[num_instructions++];
   assert(num_instructions < ARRAY_SIZE(instructions));
   ai->tok = tok;
   return ai;
}

void
decl_label(const char *str)
{
   struct asm_label *label = &labels[num_labels++];

   assert(num_labels < ARRAY_SIZE(labels));

   label->offset = num_instructions;
   label->label = str;
}

static int
resolve_label(const char *str)
{
   int i;

   for (i = 0; i < num_labels; i++) {
      struct asm_label *label = &labels[i];

      if (!strcmp(str, label->label)) {
         return label->offset;
      }
   }

   fprintf(stderr, "Undeclared label: %s\n", str);
   exit(2);
}

static afuc_opc
tok2alu(int tok)
{
   switch (tok) {
   case T_OP_ADD:
      return OPC_ADD;
   case T_OP_ADDHI:
      return OPC_ADDHI;
   case T_OP_SUB:
      return OPC_SUB;
   case T_OP_SUBHI:
      return OPC_SUBHI;
   case T_OP_AND:
      return OPC_AND;
   case T_OP_OR:
      return OPC_OR;
   case T_OP_XOR:
      return OPC_XOR;
   case T_OP_NOT:
      return OPC_NOT;
   case T_OP_SHL:
      return OPC_SHL;
   case T_OP_USHR:
      return OPC_USHR;
   case T_OP_ISHR:
      return OPC_ISHR;
   case T_OP_ROT:
      return OPC_ROT;
   case T_OP_MUL8:
      return OPC_MUL8;
   case T_OP_MIN:
      return OPC_MIN;
   case T_OP_MAX:
      return OPC_MAX;
   case T_OP_CMP:
      return OPC_CMP;
   case T_OP_MSB:
      return OPC_MSB;
   default:
      assert(0);
      return -1;
   }
}

static void
emit_instructions(int outfd)
{
   int i;

   /* there is an extra 0x00000000 which kernel strips off.. we could
    * perhaps use it for versioning.
    */
   i = 0;
   write(outfd, &i, 4);

   for (i = 0; i < num_instructions; i++) {
      struct asm_instruction *ai = &instructions[i];
      afuc_instr instr = {0};
      afuc_opc opc;

      /* special case, 2nd dword is patched up w/ # of instructions
       * (ie. offset of jmptbl)
       */
      if (i == 1) {
         assert(ai->is_literal);
         ai->literal &= ~0xffff;
         ai->literal |= num_instructions;
      }

      if (ai->is_literal) {
         write(outfd, &ai->literal, 4);
         continue;
      }

      switch (ai->tok) {
      case T_OP_NOP:
         opc = OPC_NOP;
         if (gpuver >= 6)
            instr.pad = 0x1000000;
         break;
      case T_OP_ADD:
      case T_OP_ADDHI:
      case T_OP_SUB:
      case T_OP_SUBHI:
      case T_OP_AND:
      case T_OP_OR:
      case T_OP_XOR:
      case T_OP_NOT:
      case T_OP_SHL:
      case T_OP_USHR:
      case T_OP_ISHR:
      case T_OP_ROT:
      case T_OP_MUL8:
      case T_OP_MIN:
      case T_OP_MAX:
      case T_OP_CMP:
      case T_OP_MSB:
         if (ai->has_immed) {
            /* MSB overlaps with STORE */
            assert(ai->tok != T_OP_MSB);
            if (ai->xmov) {
               fprintf(stderr,
                       "ALU instruction cannot have immediate and xmov\n");
               exit(1);
            }
            opc = tok2alu(ai->tok);
            instr.alui.dst = ai->dst;
            instr.alui.src = ai->src1;
            instr.alui.uimm = ai->immed;
         } else {
            opc = OPC_ALU;
            instr.alu.dst = ai->dst;
            instr.alu.src1 = ai->src1;
            instr.alu.src2 = ai->src2;
            instr.alu.xmov = ai->xmov;
            instr.alu.alu = tok2alu(ai->tok);
         }
         break;
      case T_OP_MOV:
         /* move can either be encoded as movi (ie. move w/ immed) or
          * an alu instruction
          */
         if ((ai->has_immed || ai->label) && ai->xmov) {
            fprintf(stderr, "ALU instruction cannot have immediate and xmov\n");
            exit(1);
         }
         if (ai->has_immed) {
            opc = OPC_MOVI;
            instr.movi.dst = ai->dst;
            instr.movi.uimm = ai->immed;
            instr.movi.shift = ai->shift;
         } else if (ai->label) {
            /* mov w/ a label is just an alias for an immediate, this
             * is useful to load the address of a constant table into
             * a register:
             */
            opc = OPC_MOVI;
            instr.movi.dst = ai->dst;
            instr.movi.uimm = resolve_label(ai->label);
            instr.movi.shift = ai->shift;
         } else {
            /* encode as: or $dst, $00, $src */
            opc = OPC_ALU;
            instr.alu.dst = ai->dst;
            instr.alu.src1 = 0x00; /* $00 reads-back 0 */
            instr.alu.src2 = ai->src1;
            instr.alu.xmov = ai->xmov;
            instr.alu.alu = OPC_OR;
         }
         break;
      case T_OP_CWRITE:
      case T_OP_CREAD:
      case T_OP_STORE:
      case T_OP_LOAD:
         if (gpuver >= 6) {
            if (ai->tok == T_OP_CWRITE) {
               opc = OPC_CWRITE6;
            } else if (ai->tok == T_OP_CREAD) {
               opc = OPC_CREAD6;
            } else if (ai->tok == T_OP_STORE) {
               opc = OPC_STORE6;
            } else if (ai->tok == T_OP_LOAD) {
               opc = OPC_LOAD6;
            }
         } else {
            if (ai->tok == T_OP_CWRITE) {
               opc = OPC_CWRITE5;
            } else if (ai->tok == T_OP_CREAD) {
               opc = OPC_CREAD5;
            } else if (ai->tok == T_OP_STORE || ai->tok == T_OP_LOAD) {
               fprintf(stderr, "load and store do not exist on a5xx\n");
               exit(1);
            }
         }
         instr.control.src1 = ai->src1;
         instr.control.src2 = ai->src2;
         instr.control.flags = ai->bit;
         instr.control.uimm = ai->immed;
         break;
      case T_OP_BRNE:
      case T_OP_BREQ:
         if (ai->has_immed) {
            opc = (ai->tok == T_OP_BRNE) ? OPC_BRNEI : OPC_BREQI;
            instr.br.bit_or_imm = ai->immed;
         } else {
            opc = (ai->tok == T_OP_BRNE) ? OPC_BRNEB : OPC_BREQB;
            instr.br.bit_or_imm = ai->bit;
         }
         instr.br.src = ai->src1;
         instr.br.ioff = resolve_label(ai->label) - i;
         break;
      case T_OP_RET:
         opc = OPC_RET;
         break;
      case T_OP_IRET:
         opc = OPC_RET;
         instr.ret.interrupt = 1;
         break;
      case T_OP_CALL:
         opc = OPC_CALL;
         instr.call.uoff = resolve_label(ai->label);
         break;
      case T_OP_PREEMPTLEAVE:
         opc = OPC_PREEMPTLEAVE6;
         instr.call.uoff = resolve_label(ai->label);
         break;
      case T_OP_SETSECURE:
         opc = OPC_SETSECURE;
         if (resolve_label(ai->label) != i + 3) {
            fprintf(stderr, "jump label %s is incorrect for setsecure\n",
                    ai->label);
            exit(1);
         }
         if (ai->src1 != 0x2) {
            fprintf(stderr, "source for setsecure must be $02\n");
            exit(1);
         }
         break;
      case T_OP_JUMP:
         /* encode jump as: brne $00, b0, #label */
         opc = OPC_BRNEB;
         instr.br.bit_or_imm = 0;
         instr.br.src = 0x00; /* $00 reads-back 0.. compare to 0 */
         instr.br.ioff = resolve_label(ai->label) - i;
         break;
      case T_OP_WAITIN:
         opc = OPC_WIN;
         break;
      default:
         unreachable("");
      }

      afuc_set_opc(&instr, opc, ai->rep);

      write(outfd, &instr, 4);
   }
}

unsigned
parse_control_reg(const char *name)
{
   /* skip leading "@" */
   return afuc_control_reg(name + 1);
}

static void
emit_jumptable(int outfd)
{
   uint32_t jmptable[0x80] = {0};
   int i;

   for (i = 0; i < num_labels; i++) {
      struct asm_label *label = &labels[i];
      int id = afuc_pm4_id(label->label);

      /* if it doesn't match a known PM4 packet-id, try to match UNKN%d: */
      if (id < 0) {
         if (sscanf(label->label, "UNKN%d", &id) != 1) {
            /* if still not found, must not belong in jump-table: */
            continue;
         }
      }

      jmptable[id] = label->offset;
   }

   write(outfd, jmptable, sizeof(jmptable));
}

static void
usage(void)
{
   fprintf(stderr, "Usage:\n"
                   "\tasm [-g GPUVER] filename.asm filename.fw\n"
                   "\t\t-g - specify GPU version (5, etc)\n");
   exit(2);
}

int
main(int argc, char **argv)
{
   FILE *in;
   char *file, *outfile;
   int c, ret, outfd;

   /* Argument parsing: */
   while ((c = getopt(argc, argv, "g:")) != -1) {
      switch (c) {
      case 'g':
         gpuver = atoi(optarg);
         break;
      default:
         usage();
      }
   }

   if (optind >= (argc + 1)) {
      fprintf(stderr, "no file specified!\n");
      usage();
   }

   file = argv[optind];
   outfile = argv[optind + 1];

   outfd = open(outfile, O_WRONLY | O_CREAT | O_TRUNC, 0644);
   if (outfd < 0) {
      fprintf(stderr, "could not open \"%s\"\n", outfile);
      usage();
   }

   in = fopen(file, "r");
   if (!in) {
      fprintf(stderr, "could not open \"%s\"\n", file);
      usage();
   }

   yyset_in(in);

   /* if gpu version not specified, infer from filename: */
   if (!gpuver) {
      if (strstr(file, "a5")) {
         gpuver = 5;
      } else if (strstr(file, "a6")) {
         gpuver = 6;
      }
   }

   ret = afuc_util_init(gpuver, false);
   if (ret < 0) {
      usage();
   }

   ret = yyparse();
   if (ret) {
      fprintf(stderr, "parse failed: %d\n", ret);
      return ret;
   }

   emit_instructions(outfd);
   emit_jumptable(outfd);

   close(outfd);

   return 0;
}
