#
#
#           The Nim Compiler
#        (c) Copyright 2018 Andreas Rumpf
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

## Intermediate representation for Nim's C++ backend.
## Experimental stuff. The IR is based on "packed" trees. I hope to replace Nim's
## AST with these trees one day. Currently this IR is designed for producing
## C++ code, but eventually it will be used for the other backends too.

import std / varints
import ast

type
  SymId* = distinct int
    ## Distinct integer handle. Presumably identifies a symbol inside the
    ## IR -- TODO confirm, nothing in this part of the file uses it yet.
  TypeId* = distinct int
    ## Distinct integer handle. Presumably identifies a type inside the
    ## IR -- TODO confirm, nothing in this part of the file uses it yet.

const
  opcodeEnd = 255u8
    ## Marker byte that closes a container in the packed byte stream;
    ## `endContainer` appends it and the scanning routines stop on it.

proc storeAtom(buf: var seq[byte]; k: NodeKind; data: string) =
  ## Appends an atom to `buf`: the kind byte `k`, then the length of `data`
  ## encoded as a varint, then the raw bytes of `data`.
  buf.add byte(k)
  let prefixAt = buf.len
  # Reserve the worst-case varint size, write the length prefix in place and
  # then shrink the buffer to what the prefix really occupies.
  buf.setLen prefixAt + maxVarIntLen
  let used = writeVu64(toOpenArray(buf, prefixAt, prefixAt + maxVarIntLen - 1),
                       uint64 data.len)
  buf.setLen prefixAt + used
  for ch in data:
    buf.add byte(ch)

proc beginContainer(buf: var seq[byte]; k: NodeKind) =
  ## Opens a container in `buf` by appending the kind `k` as a single byte.
  ## The container has to be closed again with `endContainer`.
  let tag = byte(k)
  buf.add(tag)

proc endContainer(buf: var seq[byte]) =
  ## Closes the innermost open container by appending `opcodeEnd` to `buf`.
  buf.add(opcodeEnd)

type
  IrNode* = object ## A view onto one encoded node inside a shared byte buffer.
    k: IrNodeKind ## Kind of the node this view refers to.
    a, b: int ## First and last index (both inclusive) of the node's bytes in `t[]`.
    t: ref seq[byte] ## Buffer holding the packed tree; `newJNull` leaves it nil.

  IrTree* = distinct IrNode ## An IrTree is an IrNode that can be mutated.

converter toIrNode*(x: IrTree): IrNode {.inline.} =
  ## Implicitly views a mutable `IrTree` as a plain `IrNode`, so that all
  ## read-only operations defined on `IrNode` also accept an `IrTree`.
  result = IrNode(x)

proc newJNull*(): IrNode =
  ## Creates a new `JNull IrNode`. No buffer is allocated: `t` stays nil
  ## and `a` and `b` stay 0.
  result = IrNode(k: JNull)

template newBody(kind, x) =
  ## Shared body of the atom constructors: fills `result` with a freshly
  ## allocated buffer that contains exactly one atom of `kind` whose
  ## payload is the string `x`, and lets `a..b` span the whole buffer.
  result.k = kind
  result.a = 0
  new(result.t)
  result.t[] = @[]
  storeAtom(result.t[], kind, x)
  result.b = high(result.t[])

proc newJString*(s: string): IrNode =
  ## Creates a new `JString IrNode` that stores the bytes of `s`.
  newBody(JString, s)

proc newJInt*(n: BiggestInt): IrNode =
  ## Creates a new `JInt IrNode`. The number is stored in its decimal
  ## string form (`$n`).
  newBody(JInt, $n)

proc newJFloat*(n: float): IrNode =
  ## Creates a new `JFloat IrNode`. The number is stored as the text
  ## produced by `formatFloat(n)`.
  newBody(JFloat, formatFloat(n))

proc newJBool*(b: bool): IrNode =
  ## Creates a new `JBool IrNode`. The value is encoded as a single byte:
  ## `opcodeTrue` for `true`, `opcodeFalse` for `false`.
  let encoded = if b: byte(opcodeTrue) else: byte(opcodeFalse)
  new(result.t)
  result.t[] = @[encoded]
  result.k = JBool
  result.a = 0
  result.b = high(result.t[])

proc newJObject*(): IrTree =
  ## Creates a new, empty `JObject IrNode`. Its buffer holds just the
  ## object opcode followed by the end marker.
  var node = IrNode(k: JObject, a: 0)
  new(node.t)
  node.t[] = @[byte opcodeObject, byte opcodeEnd]
  node.b = high(node.t[])
  result = IrTree(node)

proc newJArray*(): IrTree =
  ## Creates a new, empty `JArray IrNode`. Its buffer holds just the
  ## array opcode followed by the end marker.
  var node = IrNode(k: JArray, a: 0)
  new(node.t)
  node.t[] = @[byte opcodeArray, byte opcodeEnd]
  node.b = high(node.t[])
  result = IrTree(node)

proc kind*(x: IrNode): IrNodeKind =
  ## Returns the kind of the node `x`.
  result = x.k

proc extractLen(x: seq[byte]; pos: int): int =
  ## For the atom whose kind byte sits at `x[pos]`, returns the number of
  ## bytes that follow the kind byte: the size of the varint length prefix
  ## plus the payload length that the prefix encodes.
  let lastIdx = min(pos + 1 + maxVarIntLen, x.high)
  var payloadLen: uint64
  let prefixLen = readVu64(toOpenArray(x, pos + 1, lastIdx), payloadLen)
  int(payloadLen) + prefixLen

proc extractSlice(x: seq[byte]; pos: int): (int, int) =
  ## For the atom whose kind byte sits at `x[pos]`, returns the tuple
  ## `(start, len)`: the index of the first payload byte and the payload
  ## length decoded from the varint prefix.
  let lastIdx = min(pos + 1 + maxVarIntLen, x.high)
  var payloadLen: uint64
  let prefixLen = readVu64(toOpenArray(x, pos + 1, lastIdx), payloadLen)
  (pos + 1 + prefixLen, int(payloadLen))

proc skip(x: seq[byte]; start: int; elements: var int): int =
  ## Scans the contents of a container, beginning at `start` (the first
  ## byte after the container's opcode), up to its matching end marker.
  ## Returns the index just past that end marker. `elements` is incremented
  ## once for every entry found directly inside the container; entries of
  ## nested containers are not counted.
  var depth = 0
  var cur = start
  while true:
    let op = x[cur] and opcodeMask
    var after = cur + 1
    case op
    of opcodeNull, opcodeBool:
      if depth == 0: inc elements
    of opcodeInt, opcodeFloat, opcodeString:
      # atoms carry a length prefix and a payload that must be jumped over
      after = cur + 1 + extractLen(x, cur)
      if depth == 0: inc elements
    of opcodeObject, opcodeArray:
      # the nested container itself is one entry of the outer container
      if depth == 0: inc elements
      inc depth
    of opcodeEnd:
      if depth == 0: return after
      dec depth
    else: discard
    cur = after

iterator items*(x: IrNode): IrNode =
  ## Yields every element of the array `x` as a node that shares `x`'s
  ## buffer. `x` has to be a JArray.
  assert x.kind == JArray
  var cur = x.a + 1
  var ignored: int
  while cur <= x.b:
    let op = x.t[cur] and opcodeMask
    var after = cur + 1
    case op
    of opcodeEnd: break
    of opcodeInt, opcodeFloat, opcodeString:
      after = cur + 1 + extractLen(x.t[], cur)
    of opcodeObject, opcodeArray:
      # jump over the whole nested container
      after = skip(x.t[], cur + 1, ignored)
    of opcodeNull, opcodeBool: discard
    else: discard
    yield IrNode(k: IrNodeKind(op), a: cur, b: after - 1, t: x.t)
    cur = after

iterator pairs*(x: IrNode): (string, IrNode) =
  ## Yields every `(key, value)` pair of the object `x`. The value nodes
  ## share `x`'s buffer. `x` has to be a JObject.
  assert x.kind == JObject
  var cur = x.a + 1
  var ignored: int
  var key = newStringOfCap(60)
  while cur <= x.b:
    let keyOp = x.t[cur] and opcodeMask
    if keyOp == opcodeEnd: break

    # every entry starts with its key, stored as a string atom
    assert keyOp == opcodeString, $keyOp
    let (keyStart, keyLen) = extractSlice(x.t[], cur)
    key.setLen keyLen
    for i in 0 ..< keyLen:
      key[i] = char(x.t[keyStart + i])
    cur = keyStart + keyLen

    # the value follows directly after the key
    let op = x.t[cur] and opcodeMask
    var after = cur + 1
    case op
    of opcodeInt, opcodeFloat, opcodeString:
      after = cur + 1 + extractLen(x.t[], cur)
    of opcodeObject, opcodeArray:
      after = skip(x.t[], cur + 1, ignored)
    of opcodeEnd: doAssert false, "unexpected end of object"
    of opcodeNull, opcodeBool: discard
    else: discard
    yield (key, IrNode(k: IrNodeKind(op), a: cur, b: after - 1, t: x.t))
    cur = after

proc rawGet(x: IrNode; name: string): IrNode =
  ## Looks up the key `name` in the object `x` and returns the value node
  ## of the first matching entry. When no entry matches, the returned node
  ## has `a == -1`.
  assert x.kind == JObject
  var cur = x.a + 1
  var ignored: int
  while cur <= x.b:
    let keyOp = x.t[cur] and opcodeMask
    if keyOp == opcodeEnd: break

    assert keyOp == opcodeString, $keyOp
    let (keyStart, keyLen) = extractSlice(x.t[], cur)
    # compare the stored key byte by byte; no temporary string is built
    var found = name.len == keyLen
    var i = 0
    while found and i < keyLen:
      found = name[i] == char(x.t[keyStart + i])
      inc i
    cur = keyStart + keyLen

    # determine where the value that follows the key ends
    let op = x.t[cur] and opcodeMask
    var after = cur + 1
    case op
    of opcodeInt, opcodeFloat, opcodeString:
      after = cur + 1 + extractLen(x.t[], cur)
    of opcodeObject, opcodeArray:
      after = skip(x.t[], cur + 1, ignored)
    of opcodeEnd: doAssert false, "unexpected end of object"
    of opcodeNull, opcodeBool: discard
    else: discard
    if found:
      return IrNode(k: IrNodeKind(op), a: cur, b: after - 1, t: x.t)
    cur = after
  result.a = -1

proc `[]`*(x: IrNode; name: string): IrNode =
  ## Returns the value stored under the key `name` in the object `x`.
  ## Raises `KeyError` when `x` has no such key.
  let found = rawGet(x, name)
  if found.a < 0:
    raise newException(KeyError, "key not found in object: " & name)
  result = found

proc len*(n: IrNode): int =
  ## Number of elements for a `JArray`, number of pairs for a `JObject`
  ## and 0 for every other kind of node.
  if n.k in {JArray, JObject}:
    discard skip(n.t[], n.a+1, result)
    # `skip` counts keys and values separately, so an object reports
    # twice as many entries as it has pairs:
    if n.k == JObject:
      result = result shr 1

proc rawAdd(obj: var IrNode; child: seq[byte]; a, b: int) =
  ## Splices the bytes `child[a..b]` into `obj`'s buffer at index `obj.b`,
  ## i.e. directly in front of the byte that currently ends `obj`, and
  ## advances `obj.b` by the number of inserted bytes.
  let pa = obj.b
  let L = b - a + 1
  let oldfull = obj.t[].len
  setLen(obj.t[], oldfull+L)
  # Move the tail (everything from `pa` on) to the new end, back to front,
  # so that a gap of `L` bytes opens up in the middle. Plain assignment is
  # used for the bytes: it is equivalent to `shallowCopy` for `byte` and is
  # also available under ARC/ORC.
  for i in countdown(oldfull+L-1, pa+L):
    obj.t[][i] = obj.t[][i-L]
  # fill the gap with the child's bytes:
  for i in 0 ..< L:
    obj.t[][pa + i] = child[a + i]
  inc obj.b, L

proc rawAdd(obj: var IrNode; child: IrNode) =
  ## Splices the encoded form of the node `child` into `obj`; this is the
  ## two-argument form that both `add` procs call. A `JNull` node may have
  ## no buffer at all (see `newJNull`), so for it the single null opcode is
  ## inserted instead of reading from `child.t`.
  if child.k == JNull:
    rawAdd(obj, @[byte opcodeNull], 0, 0)
  else:
    rawAdd(obj, child.t[], child.a, child.b)

proc add*(parent: var IrTree; child: IrNode) =
  ## Appends `child` as the last element of the array `parent`.
  ## A `doAssert` fails when `parent` is not a `JArray`.
  doAssert(kind(parent) == JArray, "parent is not a JArray")
  IrNode(parent).rawAdd(child)

proc add*(obj: var IrTree, key: string, val: IrNode) =
  ## Appends the pair `key`/`val` to the object `obj`.
  ## **Warning**: it is not checked whether `obj` already has a field named
  ## `key`; the caller has to guarantee that it does not.
  assert kind(obj) == JObject
  # XXX optimize this further! The key is first encoded into a temporary
  # string node and then copied over.
  let encodedKey = newJString(key)
  rawAdd(IrNode obj, encodedKey.t[], 0, high(encodedKey.t[]))
  rawAdd(IrNode obj, val)
  # XXX assert that the key does not exist yet

proc rawPut(obj: var IrNode, oldval: IrNode, key: string, val: IrNode): int =
  ## Replaces the encoded bytes of `oldval` (a node located inside `obj`'s
  ## buffer) with the encoded bytes of `val`, growing or shrinking the buffer
  ## as needed. Returns the size difference in bytes (new minus old); the
  ## caller is responsible for adjusting `obj.b` by that amount.
  ## `key` is not used by the implementation.
  let oldlen = oldval.b - oldval.a + 1
  let newlen = val.b - val.a + 1
  result = newlen - oldlen
  if result != 0:
    let oldfull = obj.t[].len
    if result > 0:
      setLen(obj.t[], oldfull+result)
      # move the tail towards the new end, back to front, to make room:
      for i in countdown(oldfull+result-1, oldval.a+newlen):
        obj.t[][i] = obj.t[][i-result]
    else:
      # `result` is negative here, so `i-result` lies behind `i`: pull the
      # tail forward, front to back, then cut the buffer down.
      for i in countup(oldval.a+newlen, oldfull+result-1):
        obj.t[][i] = obj.t[][i-result]
      setLen(obj.t[], oldfull+result)
  # Overwrite the old value. The bytes of `val` start at `val.a`, which is
  # not 0 when `val` is a sub-node of a bigger tree, so the offset must be
  # applied. A `JNull` node may have no buffer and is written as the single
  # null opcode.
  for i in 0 ..< newlen:
    obj.t[][oldval.a + i] =
      (if val.k == JNull: byte opcodeNull else: val.t[][val.a + i])

proc `[]=`*(obj: var IrTree, key: string, val: IrNode) =
  ## Stores `val` under `key` in the object `obj`: an existing value is
  ## overwritten in place, otherwise the pair is appended.
  let existing = rawGet(obj, key)
  if existing.a >= 0:
    # the replacement may differ in size, so the end of `obj` moves along
    let delta = rawPut(IrNode obj, existing, key, val)
    inc IrNode(obj).b, delta
  else:
    add(obj, key, val)

macro `[]=`*(obj: var IrTree, keys: varargs[typed], val: IrNode): untyped =
  ## Nested assignment: generates code that walks `obj` along all `keys`
  ## via `[]` and then overwrites the node found there with `val`.
  ## Every key, including the last one, is looked up with `[]`, so the
  ## complete path has to exist already.
  ## keys can be strings or integers for the navigation.
  ## NOTE(review): only the string overload of `[]` is visible in this part
  ## of the file -- confirm that an integer overload exists.
  result = newStmtList()
  # `var oldval = obj[key]`; dirty, so that `oldval` is visible to the
  # statements generated after it.
  template t0(obj, key) {.dirty.} =
    var oldval = obj[key]

  # `oldval = oldval[key]`: descends one level deeper.
  template ti(key) {.dirty.} =
    oldval = oldval[key]

  # Overwrites `oldval` inside `obj` and moves the end of `obj` by the
  # size difference that `rawPut` reports.
  template tput(obj, finalkey, val) =
    let diff = rawPut(IrNode obj, oldval, finalkey, val)
    inc IrNode(obj).b, diff

  result.add getAst(t0(obj, keys[0]))
  for i in 1..<len(keys):
    result.add getAst(ti(keys[i]))
  result.add getAst(tput(obj, keys[len(keys)-1], val))

proc rawDelete(x: var IrNode, key: string) =
  ## Removes the first entry of the object `x` whose key equals `key`: the
  ## bytes of the key and of its value are cut out of the buffer and `x.b`
  ## is reduced accordingly. Raises `KeyError` when no entry matches.
  assert x.kind == JObject
  var pos = x.a+1
  var dummy: int
  while pos <= x.b:
    let k2 = x.t[pos] and opcodeMask
    if k2 == opcodeEnd: break

    assert k2 == opcodeString, $k2
    # remember where the entry (key plus value) starts:
    let begin = pos
    let (start, L) = extractSlice(x.t[], pos)
    # compare for the key without creating the temp string:
    var isMatch = key.len == L
    if isMatch:
      for i in 0 ..< L:
        if key[i] != char(x.t[start+i]):
          isMatch = false
          break
    pos = start + L

    # determine where the value that follows the key ends:
    let k = x.t[pos] and opcodeMask
    var nextPos = pos + 1
    case k
    of opcodeNull, opcodeBool: discard
    of opcodeInt, opcodeFloat, opcodeString:
      let L = extractLen(x.t[], pos)
      nextPos = pos + 1 + L
    of opcodeObject, opcodeArray:
      nextPos = skip(x.t[], pos+1, dummy)
    of opcodeEnd: doAssert false, "unexpected end of object"
    else: discard
    if isMatch:
      let diff = nextPos - begin
      let oldfull = x.t[].len
      # Close the gap by pulling the tail forward. Plain assignment is used
      # for the bytes: it is equivalent to `shallowCopy` for `byte` and is
      # also available under ARC/ORC.
      for i in countup(begin, oldfull-diff-1):
        x.t[][i] = x.t[][i+diff]
      setLen(x.t[], oldfull-diff)
      dec x.b, diff
      return
    pos = nextPos
  # for compatibility with json.nim, we need to raise an exception
  # here. Not sure it's good idea.
  raise newException(KeyError, "key not in object: " & key)

proc delete*(x: var IrTree, key: string) =
  ## Removes the entry stored under `key` from the object `x`, i.e. deletes
  ## ``x[key]``. Raises `KeyError` when `x` has no such key.
  rawDelete(IrNode(x), key)
