import argparse
import csv
import sys
from chardet import UniversalDetector
import re
import datetime


# Running count of text-decoding problems. _transcode_text() increments it for
# every ambiguous or undecodable string, and main() exits non-zero if it is set.
errors = 0

class DBDatetime():
    """Placeholder that remembers how it was constructed and prints it back.

    Positional arguments are rendered with ``str`` and keyword arguments as
    ``name=value``; the comma-joined result is kept in ``self.a``.
    NOTE(review): presumably this lets annotation data that mentions
    ``DBDatetime(...)`` be evaluated and written out again unchanged --
    confirm against the export format.
    """

    def __init__(self, *args, **kwargs):
        positional = [str(value) for value in args]
        keyword = [f'{name}={value}' for name, value in kwargs.items()]
        self.a = ', '.join(positional + keyword)

    def __repr__(self):
        return f'DBDatetime({self.a})'

    def __str__(self):
        return repr(self)
class UTC():
    """Placeholder whose string and repr forms are both the text ``UTC()``."""

    def __repr__(self):
        return 'UTC()'

    def __str__(self):
        return repr(self)

def decode_hex(m):
    """Replacement callback for ``re.sub``: turn two hex digits into one byte.

    ``m`` is a match whose first group holds the hex digits. A value of zero
    is not emitted as a raw NUL byte; three bytes (two backslashes followed
    by the digit 0) are returned in its place.
    """
    value = int(m.group(1), 16)
    return bytes([value]) if value else b'\\\\0'

def decode_code_point(c):
    """Replacement callback for ``re.sub``: map a decimal code point to its character.

    ``c`` is a match whose first group holds the code point as decimal digits.
    """
    code_point = int(c.group(1))
    return chr(code_point)

def _transcode_text(t, encodings, line):
    """Recover the intended text of one annotation string.

    The string is turned back into bytes (UTF-8), escaped bytes written as
    ``&#xE0NN;`` or as backslash-x hex escapes are expanded, and the result is
    decoded with each candidate encoding in turn. Decimal ``&amp;#NNNNN;``
    code-point references are then replaced by their characters.

    Args:
        t: the text to transcode.
        encodings: candidate encoding names, tried in order.
        line: position of the record in the input, used only in messages.

    Returns:
        The first successful interpretation. If no encoding can decode the
        bytes, ``t`` is returned unchanged.

    Side effects:
        Prints a report and increments the module-level ``errors`` counter
        when the candidates disagree or none of them works.
    """
    global errors
    # Expand both escape forms once, up front, so every candidate encoding is
    # tried against exactly the same bytes. The hex class is [0-9a-fA-F]; a
    # range ending in lowercase "f" would also accept G-Z and punctuation and
    # make int(..., 16) in decode_hex raise ValueError.
    raw = re.sub(br'&#xE0([0-9a-fA-F][0-9a-fA-F]);', decode_hex, t.encode('utf8'))
    # some browsers apparently throw in hex codes for non-ascii
    raw = re.sub(br'\\[Xx]([0-9a-fA-F][0-9a-fA-F])', decode_hex, raw)

    valid_interps = {}  # decoded text -> encoding that produced it
    first_interp = None
    for encoding in encodings:
        try:
            transcoded = raw.decode(encoding)
        except UnicodeError:
            continue
        # some browsers throw in unicode code points
        transcoded = re.sub(r'&amp;#(\d\d\d\d\d);', decode_code_point, transcoded)
        if first_interp is None:
            first_interp = transcoded
        valid_interps[transcoded] = encoding

    if not valid_interps:
        errors += 1
        print(f'\nCould not decode line {line} "{t}" ("{raw}").')
        return t
    if len(valid_interps) > 1:
        # Different encodings produced different text; report all of them but
        # still return the first one so processing can continue.
        errors += 1
        print(f'\nAmbiguous text on line {line}: "{t}" ("{raw}").')
        for interp, enc in valid_interps.items():
            print(f'{enc}: {interp}')
    return first_interp


def recursive_transcode(o, encodings, line):
    """Apply ``_transcode_text`` to every string nested inside ``o``.

    Lists and tuples are rebuilt with the same type, dicts have both keys and
    values processed, strings are transcoded, and any other object is returned
    as is.

    Args:
        o: the value to walk.
        encodings: candidate encoding names, passed through to ``_transcode_text``.
        line: record position, passed through for error messages.

    Returns:
        A structure of the same shape as ``o`` with its strings transcoded.
    """
    if isinstance(o, str):
        return _transcode_text(o, encodings, line)
    if isinstance(o, (list, tuple)):
        items = [recursive_transcode(item, encodings, line) for item in o]
        return type(o)(items)
    if isinstance(o, dict):
        # items() must be called; iterating the bound method raises TypeError.
        return {
            recursive_transcode(key, encodings, line): recursive_transcode(value, encodings, line)
            for key, value in o.items()
        }
    return o

def do_transcode(r, w, encodings):
    """Copy annotation CSV data from ``r`` to ``w``, transcoding each record.

    The first row is copied through as the header. For every following row,
    column index 3 is evaluated as a Python expression, all strings inside it
    are transcoded, and the result is written back with ``str``. Column
    index 2 is rebuilt by joining, with newlines, element ``[2]`` of every
    item of that value that is not ``None``.

    Args:
        r: text file object to read CSV from.
        w: text file object to write CSV to.
        encodings: candidate encoding names, tried in order for each string.

    Returns:
        The module-level ``errors`` counter after processing.
    """
    reader = csv.reader(r)
    writer = csv.writer(w)
    header = next(reader, None)
    if header is None:
        # Completely empty input: nothing to copy.
        return errors
    writer.writerow(header)
    # Report the physical line on which each record starts. Records can span
    # several lines (quoted newlines), so a simple per-row counter drifts.
    record_line = reader.line_num + 1
    for row in reader:
        # Doubling the backslash keeps "\xNN" sequences as literal text through
        # eval() so that they can be decoded later as bytes.
        # NOTE(security): eval() executes whatever expression the column holds;
        # only run this on annotation files from a trusted source.
        contents = eval(row[3].replace(r'\x', r'\\x'))
        contents = recursive_transcode(contents, encodings, record_line)
        row[3] = str(contents)
        row[2] = '\n'.join(item[2] for item in contents if item[2] is not None)
        writer.writerow(row)
        record_line = reader.line_num + 1
    return errors

def main(argv):
    """Command-line entry point: transcode an annotations CSV file.

    Args:
        argv: full argument vector; ``argv[0]`` (the program name) is skipped.

    If no encodings are given on the command line, the input file is sampled
    with chardet and the user is asked to confirm or override the guess.
    ``utf8`` is always appended as the last candidate. Exits with status 1
    when the encoding cannot be guessed or when any decoding errors were
    counted.
    """
    global errors
    argp = argparse.ArgumentParser(
            description='Transcode pre 6.1 annotations exported from a hub upgraded to this version.')
    argp.add_argument(
        'input_annotations',
        help='Annotations file containing pre 6.1 annotations.')
    argp.add_argument('output_annotations')
    argp.add_argument('encodings',
        metavar='input_encoding',
        nargs='*',
        help='List of encodings in the annotations file. This will try decoding each line one by one using each of these in order until one works.')
    arg_obj = argp.parse_args(argv[1:])
    encodings = arg_obj.encodings
    if not encodings:
        det = UniversalDetector()
        with open(arg_obj.input_annotations, 'rb') as f:
            for i, line in enumerate(f):
                # Expand escaped bytes first so the detector sees real bytes.
                # The hex class must end in uppercase "F"; a range ending in
                # lowercase "f" also matches non-hex characters and makes
                # decode_hex raise ValueError.
                line = re.sub(rb'&#xE0([0-9a-fA-F][0-9a-fA-F]);', decode_hex, line)
                line = re.sub(br'\\[Xx]([0-9a-fA-F][0-9a-fA-F])', decode_hex, line)
                det.feed(line)
                if det.done or i > 10000:
                    break
        det.close()
        detected = det.result['encoding']
        if detected is None:
            print('This encoding is mixed. You will have to list encodings as arguments to this command.')
            sys.exit(1)
        confidence = det.result['confidence'] * 100
        print(f'This looks like "{detected}" encoding. ({confidence:.1f}% confidence)')
        print('If it is not, type the correct encoding now.')
        print('Otherwise, press enter to continue.', end='')
        sys.stdout.flush()
        # split() with no argument ignores repeated/leading/trailing blanks, so
        # no empty encoding names are produced; blank input keeps the guess.
        encodings = input().split() or [detected]

    encodings.append('utf8')
    # The file is read as UTF-8; _transcode_text re-encodes each string to
    # UTF-8 bytes before trying the candidate encodings.
    with open(arg_obj.input_annotations, 'r', encoding='utf8', newline='') as r, \
            open(arg_obj.output_annotations, 'w', encoding='utf8', newline='') as w:
        do_transcode(r, w, encodings)

    print('\nIf the new annotations are being re-imported into the hub they were exported from, '
          'you will have to first delete its annotation comments.\n'
          'BEWARE: VISITING THIS URL WILL INSTANTLY DELETE ALL COMMENTS ON YOUR HUB.\n'
          r'Go to /sql.html?sql=update%20cs_warningreport%20set%20comments_text_xml%20%3D%27%27%2C%20comment_index_xml%3D%27%27,'
          ' then restart the hub.\n'
          f'{errors} errors detected; see details above')
    if errors:
        sys.exit(1)

# Run the tool only when executed as a script. The full sys.argv is passed;
# main() drops the program name itself.
if __name__ == '__main__':
    main(sys.argv)
