BaseJpegDecode.py

#!/usr/bin/env python
# coding:utf-8
# BaseJpegDecode.py 2012.10.23
import abc
import logging
from collections import Counter

class BitPattern:
    def __init__(self, data=0, size=0):
        self.data = data
        self.size = size

    def clear(self):
        self.data = 0
        self.size = 0

    def put(self, c, size=8):
        self.data <<= size
        self.data |= c
        self.size += size
        if self.size > 32:
            logging.error("%08X %d" % (self.data, self.size))
            raise ValueError('over flow 32bit')

    def peek(self, size):
        return self.data >> (self.size-size)

    def get(self, size):
        r = self.peek(size)
        self.size -= size
        self.data &= (1<<self.size)-1
        return r

    def match(self, b):
        if b.size > self.size:
            return False
        return self.peek(b.size) == b.peek(b.size)

class Huff:
    def __init__(self, run, value_size, code):
        self.run = run
        self.value_size = value_size
        self.code = code

class HuffmanDecode(object):
    def __init__(self):
        self._ht = {}

    def _ht_clear(self, tc, th):
        self._ht[(tc,th)] = []

    def _ht_append(self, tc, th, huff):
        self._ht[(tc,th)].append(huff)

    def inputDHT(self, c, pos, size):
        if pos == 0:
            self._buf = []
        self._buf.append(c)
        if pos < (size-1):
            return
        pos = 0
        while pos < len(self._buf):
            uc = self._buf[pos]
            pos += 1
            tc = uc >> 4
            th = uc & 0x0f
            logging.info("DHT: Tc=%d Th=%d" %(tc, th))
            self._ht_clear(tc, th)
            l_pos = pos
            pos += 16
            code = 0x0000
            for i in range(16):
                l = self._buf[l_pos + i]
                for k in range(l):
                    value = self._buf[pos]
                    run = value >> 4;
                    value_size = value & 0x0f
                    h = Huff(run, value_size, BitPattern(code, i+1))
                    self._ht_append(tc, th, h)
                    pos += 1
                    code += 1
                code <<= 1

    def Lookup(self, tc, th, bitpat):
        for h in self._ht[(tc,th)]:
            if h.code.size > bitpat.size:
                return None
            if bitpat.match(h.code):
                return h
        logging.error("(%d,%d)%08X %d" % (tc, th, bitpat.data, bitpat.size))
        raise ValueError('Huffman decode error')
        return None

    def getValue(self, huff, bitpat):
        if huff.value_size == 0:
            return 0
        value = bitpat.get(huff.value_size)
        if value & (1<<(huff.value_size-1)):
            return value
        value -= (1<<huff.value_size)-1
        return value

MARK_SOF0 = 0xc0
MARK_DHT  = 0xc4
MARK_RST0 = 0xd0
MARK_RST7 = 0xd7
MARK_SOI  = 0xd8
MARK_EOI  = 0xd9
MARK_SOS  = 0xda
MARK_DQT  = 0xdb
MARK_DRI  = 0xdd
MARK_APP  = 0xe0

SEQ_INIT = 0
SEQ_MARK = 1
SEQ_SEG_LEN  = 2
SEQ_SEG_LEN2 = 3
SEQ_SEG_BODY = 4
SEQ_SOS  = 5
SEQ_SOS2 = 6

class BaseJpegDecode(object):
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        self._hd = HuffmanDecode()
        self.qt = {}
        self.clear()

    def clear(self):
        self._seq = SEQ_INIT

    @abc.abstractmethod
    def outputDC(self, mcu, block, value):
        pass

    @abc.abstractmethod
    def outputAC(self, mcu, block, scan, value):
        pass

    @abc.abstractmethod
    def outputMARK(self, c):
        pass

    def _restart(self):
        self._block = 0
        self._scan = 0
        self._old_DC_value = Counter()
        self._bitpat = BitPattern()
        self._huff = None

    def _inputScan(self, c):
        self._bitpat.put(c)
        while self._bitpat.size > 0:
            if self._scan == 0:
                tc = 0 # DC
            else:
                tc = 1 # AC
            if self._block < self._yblock: # 2 or 4
                th = 0 # Y
            else:
                th = 1 # CbCr
            if self._huff == None:
                self._huff = self._hd.Lookup(tc, th, self._bitpat)
                if self._huff == None:
                    break
                self._bitpat.get(self._huff.code.size) # skip code
            if self._huff.value_size > self._bitpat.size:
                break
            value = self._hd.getValue(self._huff, self._bitpat)
            if self._scan == 0: #DC
                value += self._old_DC_value[th]
                self.outputDC(self._mcu, self._block, value)
                self._old_DC_value[th] = value
                self._scan += 1
            else: # AC
                if self._huff.run == 0 and self._huff.value_size == 0: # EOB
                    self._scan = 63
                    self.outputAC(self._mcu, self._block, self._scan, 0)
                    self._scan = 64
                else:
                    for i in range(self._huff.run):
                        self.outputAC(self._mcu, self._block, self._scan, 0)
                        self._scan += 1
                    self.outputAC(self._mcu, self._block, self._scan, value)
                    self._scan += 1
                if self._scan >= 64:
                    self._scan = 0
                    self._block += 1
                    if self._block >= (self._yblock+2): # 4 or 6
                        self._block = 0
                        self._mcu += 1
            self._huff = None

    def _inputDQT(self, c, pos, size):
        if pos == 0 or pos == 65:
            self._tq = c
            self.qt[self._tq] = []
        else:
            self.qt[self._tq].append(c)
        if pos == (size-1): # last
            for tq in range(2):
                if tq in self.qt:
                    s = ",".join(["%d" % c for c in self.qt[tq]])
                    logging.info("DQT(%d): %s" % (tq, s))

    def _inputSOF(self, c, pos, len):
        if pos == 1:
            self.height = c<<8
        elif pos == 2:
            self.height += c
        elif pos == 3:
            self.width = c<<8
        elif pos == 4:
            self.width += c
        elif pos == 7:
            if c == 0x22:
                self._yblock = 4
            elif c == 0x21:
                self._yblock = 2
            else:
                raise ValueError('SOF error')
        if pos == (len-1):
            logging.info("SOF: width=%d height=%d yblock=%d" % (self.width, self.height, self._yblock))

    def _inputSOS(self, c, pos, len):
        if pos == 0:
            self._buf = []
        self._buf.append(c)
        if pos == (len-1): #last
            s = ",".join(["%02X" % c for c in self._buf])
            logging.info("SOS: "+ s)

    def input(self, c):
        if c < 0 or c > 0xff:
            raise ValueError('input error')
        if self._seq == SEQ_INIT:
            if c == 0xff:
                self._seq = SEQ_MARK
        elif self._seq == SEQ_MARK:
            self.outputMARK(c)
            if c == MARK_SOI:
                self._seq = SEQ_INIT
            else:
                self._mark = c
                self._seq = SEQ_SEG_LEN
        elif self._seq == SEQ_SEG_LEN:
            self._seg_len = c << 8
            self._seq = SEQ_SEG_LEN2
        elif self._seq == SEQ_SEG_LEN2:
            self._seg_len += c
            self._seg_len -= 2
            self._seg_pos = 0
            self._seq = SEQ_SEG_BODY
        elif self._seq == SEQ_SEG_BODY:
            if self._mark == MARK_SOS: # SOS
                self._inputSOS(c, self._seg_pos, self._seg_len)
            elif self._mark == MARK_SOF0: # SOF0
                self._inputSOF(c, self._seg_pos, self._seg_len)
            elif self._mark == MARK_DQT: # DQT
                self._inputDQT(c, self._seg_pos, self._seg_len)
            elif self._mark == MARK_DHT: # DHT
                self._hd.inputDHT(c, self._seg_pos, self._seg_len)
            else:
                pass
            self._seg_pos += 1
            if self._seg_pos < self._seg_len:
                return
            if self._mark == MARK_SOS: # SOS
                self._mcu = 0
                self._restart()
                self._seq = SEQ_SOS
            else:
                self._seq = SEQ_INIT
        elif self._seq == SEQ_SOS:
            if c == 0xff:
                self._seq = SEQ_SOS2
            else:
                self._inputScan(c)
        elif self._seq == SEQ_SOS2:
            if c == 0x00:
                self._inputScan(0xff)
                self._seq = SEQ_SOS
            elif c >= MARK_RST0 and c <= MARK_RST7: # RSTx
                self._restart()
                self._seq = SEQ_SOS
            elif c == MARK_EOI: # EOI
                self.outputMARK(c)
                self._seq = SEQ_INIT
            else:
                self.outputMARK(c)
                self._seq = SEQ_INIT
        else:
            pass

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    class DemoJpeg(BaseJpegDecode):
        def __init__(self):
            super(DemoJpeg, self).__init__()

        def outputDC(self, mcu, block, value):
            if self._yblock == 2:
                if block <= 1:
                    print "%3d" % value,
                    if block == 1:
                        if (mcu % (self.width/16)) == (self.width/16)-1:
                            print ""
                return
            if block == 0:
                self._value = 0
            if block <= 3:
                self._value += (value+512)
                if block == 3:
                    print("%02X" % (self._value/16)),
                    if (mcu % (self.width/16)) == (self.width/16)-1:
                        print ""

        def outputAC(self, mcu, block, scan, value):
            pass

        def outputMARK(self, c):
            print("MARK: %02X" % c)

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('infiles', nargs='*')
    args = parser.parse_args()

    jpeg = DemoJpeg()

    for filename in args.infiles:
        with open(filename, "rb") as f:
            data = f.read()
        print("%s %d" % (filename, len(data)))
        jpeg.clear()
        for c in data:
            jpeg.input(ord(c))