# Dump all JPEG tags from an SWF file
import  os
import  zlib
import  struct
import  StringIO

import  Image


# Helpers for reading SWF files
def CalcMaskShift(pos, len):
    """Return ``(mask, shift)`` selecting ``len`` bits of a byte.

    ``pos`` is the index of the highest selected bit (7 = most significant).
    ``mask`` isolates the field and ``shift`` is the right-shift that moves
    the isolated field down to bit 0.
    """
    lowest_bit = pos + 1 - len
    field_of_ones = (1 << len) - 1
    return field_of_ones << lowest_bit, lowest_bit

class BitStream(object):
    """Bit- and byte-level reader layered over a binary file-like object.

    One byte of look-ahead is always held: ``curr_byte`` is the integer value
    of the byte currently being consumed (``None`` once ``fp`` is exhausted)
    and ``bit_pos`` is the index of its next unread bit, 7 being the most
    significant.  Consuming a byte that does not exist raises ``EOFError``.
    """

    # (bit_pos, size) -> (mask, shift) for extracting `size` bits whose
    # highest bit sits at `bit_pos`; these are the same pairs that
    # CalcMaskShift(bit_pos, size) produces, precomputed for every legal key.
    lut = dict(
        ((pos, size), (((1 << size) - 1) << (pos - size + 1), pos - size + 1))
        for pos in range(8) for size in range(1, pos + 2))

    def __init__(self, fp):
        """Wrap ``fp`` (opened for binary reading) and load its first byte."""
        self.fp = fp
        self.next()

    def next(self):
        """Load the following byte into ``curr_byte`` and reset ``bit_pos``."""
        c = self.fp.read(1)
        if c:
            self.curr_byte = ord(c)
        else:
            # End of data: nothing was consumed from fp by this read.
            self.curr_byte = None
        self.bit_pos = 7

    def tell(self):
        """Return ``(byte_offset, bit_pos)`` of the next unread bit."""
        if self.curr_byte is None:
            # The failed look-ahead read did not advance fp, so fp.tell() is
            # already the current offset; subtracting one here would make a
            # tell()/seek() round trip step back over the last byte.
            return (self.fp.tell(), self.bit_pos)
        return (self.fp.tell() - 1, self.bit_pos)

    def seek(self, curr_byte, bit_pos=7):
        """Reposition to byte offset ``curr_byte`` and bit index ``bit_pos``."""
        self.fp.seek(curr_byte)
        self.next()
        self.bit_pos = bit_pos

    def align(self):
        """Discard the rest of a partially consumed byte, if any."""
        if self.bit_pos != 7:
            self.next()

    def make_signed(self, val, size):
        """Reinterpret the unsigned ``size``-bit ``val`` as two's complement."""
        if size < 1:
            # A zero-width field has no sign bit to test.
            return val
        flag = 1 << (size - 1)
        if val >= flag:
            return val - flag - flag
        return val

    def _current_byte(self):
        """Return ``curr_byte``; raise ``EOFError`` when the data has run out."""
        if self.curr_byte is None:
            raise EOFError('unexpected end of stream')
        return self.curr_byte

    # Bit order is preserved; msb stored in msb
    def read_partial_byte(self, size):
        """Read ``size`` bits that all lie within the current byte.

        ``size`` must be between 1 and ``bit_pos + 1``; any other value has no
        entry in ``lut`` and raises ``KeyError``.
        """
        mask, shift = self.lut[(self.bit_pos, size)]
        rv = (self._current_byte() & mask) >> shift
        # Wrapping round to 7 means the byte is used up, so fetch the next.
        self.bit_pos = (self.bit_pos - size) & 0x7
        if self.bit_pos == 7:
            self.next()
        return rv

    # Bitfields are stored in MSB order
    def read_bits(self, size, signedp):
        """Read a ``size``-bit field, most significant bit first.

        Returns the value as an integer, sign-extended when ``signedp`` is
        true.  The field may start mid-byte and span several bytes.
        """
        # Split the field into: the remainder of the current byte, a run of
        # whole bytes, and the leading bits of one final byte.
        head_len = min(size, self.bit_pos + 1)
        body_len = (size - head_len) // 8    # floor division on every Python
        tail_len = (size - head_len) & 0x7
        rv = 0
        if head_len:
            rv = self.read_partial_byte(head_len)
        for _ in range(body_len):
            rv = rv * 256 + self._current_byte()
            self.next()
        if tail_len:
            rv = (rv << tail_len) + self.read_partial_byte(tail_len)
        if signedp:
            rv = self.make_signed(rv, size)
        return rv

    # Byte values are stored in LSB order
    def read_bytes(self, size, signedp):
        """Read a byte-aligned, little-endian integer of ``size`` bytes.

        Any partially consumed byte is discarded first.  The result is
        sign-extended when ``signedp`` is true.
        """
        self.align()
        rv = 0
        for index in range(size):
            rv += self._current_byte() << (8 * index)
            self.next()
        if signedp:
            rv = self.make_signed(rv, size * 8)
        return rv

    def read_raw(self, size):
        """Read ``size`` byte-aligned bytes and return them as a string.

        A non-positive ``size`` returns an empty string and consumes nothing.
        Fewer than ``size`` bytes are returned if the data ends early.
        """
        self.align()
        if size <= 0:
            # fp.read() treats a negative count as "read to end of file", so
            # size - 1 must never be forwarded for an empty request.
            # read(0) yields an empty string of the stream's own type.
            return self.fp.read(0)
        # struct.pack('B', n) is the one-byte string for n (what chr(n) gives
        # on Python 2) and stays a byte string on Python 3 as well.
        data = struct.pack('B', self._current_byte()) + self.fp.read(size - 1)
        self.next()
        return data


# Low-level extraction code
def skip_bytes(bs, size):
    """Move ``bs`` forward ``size`` bytes from the byte it is currently on.

    Only the byte part of ``bs.tell()`` is used; ``bs.seek()`` is called with
    its default bit position, so any partial-byte offset is dropped.
    """
    current_byte = bs.tell()[0]
    bs.seek(current_byte + size)
    
def read_rect(bs):
    """Read a bit-packed rectangle and return ``(xmin, xmax, ymin, ymax)``.

    A 5-bit unsigned count gives the width of each of the four signed
    fields that follow, read in the order they are returned.
    """
    field_width = bs.read_bits(5, False)
    return tuple(bs.read_bits(field_width, True) for _ in range(4))

def read_movie_header(bs):
    """Read the movie header and return ``(rect, fps, nframes)``.

    ``rect`` is a list holding the four values from ``read_rect`` each
    divided by 20.0 (presumably twips converted to pixels -- confirm against
    the SWF specification).  ``fps`` is a 16-bit unsigned value divided by
    256.0 and ``nframes`` is a 16-bit unsigned count.
    """
    raw_rect = read_rect(bs)
    frame_rate = bs.read_bytes(2, False) / 256.0
    frame_count = bs.read_bytes(2, False)
    scaled_rect = [coord / 20.0 for coord in raw_rect]
    return (scaled_rect, frame_rate, frame_count)

def read_jpeg_table(bs, length):
    """Read a JPEG-table tag body, dropping its 2-byte end-of-stream marker.

    ``length`` is the size of the tag body in bytes.  Returns the first
    ``length - 2`` bytes as a string and leaves ``bs`` positioned just past
    the tag.  A body of two bytes or fewer holds no table data: it is
    skipped and an empty string is returned.
    """
    if length <= 2:
        # A non-positive count must never reach read_raw(): a negative size
        # handed to a file's read() means "read everything up to EOF", which
        # would swallow every tag that follows this one.
        if length > 0:
            skip_bytes(bs, length)
        return ''
    table = bs.read_raw(length - 2)
    # Strip the trailing end-of-stream
    skip_bytes(bs, 2)
    return table

def read_jpeg_bits(bs, length, table):
    """Read a JPEG-bits tag body and return ``(id, jpeg_data)``.

    The body is a 16-bit bitmap ID, a 2-byte beginning-of-stream marker that
    is skipped, and ``length - 4`` bytes of image data.  ``table`` is placed
    in front of the image data so the result is one complete JPEG.
    """
    bitmap_id = bs.read_bytes(2, False)
    # The table already supplies the opening marker, so drop this one.
    skip_bytes(bs, 2)
    image_data = bs.read_raw(length - 4)
    return (bitmap_id, table + image_data)

def read_jpeg_bits_2(bs, length):
    """Read a self-contained JPEG tag body and return ``(id, jpeg_data)``.

    The body is a 16-bit bitmap ID followed by ``length - 2`` bytes that are
    returned untouched.  Contrary to the documentation, the data appears to
    be a single JPEG stream: no FF D9 FF D8 quad has been seen in it.
    """
    bitmap_id = bs.read_bytes(2, False)
    jpeg_data = bs.read_raw(length - 2)
    return (bitmap_id, jpeg_data)

def read_jpeg_bits_3(bs, length):
    """Read a JPEG-plus-alpha tag body; return ``(id, jpeg_data, alpha_data)``.

    The body is a 16-bit bitmap ID, a 32-bit JPEG byte count, that many bytes
    of JPEG data, and zlib-compressed alpha data filling the remaining
    ``length - 6 - jpeg_size`` bytes.  The alpha data is returned
    decompressed.
    """
    bitmap_id = bs.read_bytes(2, False)
    jpeg_size = bs.read_bytes(4, False)
    # Most applications reject SWF's two-streams-per-file layout, so crudely
    # remove every end-of-stream / start-of-stream marker pair.  Slightly
    # risky, but the odds of the four bytes occurring by chance are 2**-32.
    jpeg_data = bs.read_raw(jpeg_size)
    jpeg_data = jpeg_data.replace('\xff\xd9\xff\xd8', '')
    alpha_size = length - 6 - jpeg_size
    alpha_data = zlib.decompress(bs.read_raw(alpha_size))
    return (bitmap_id, jpeg_data, alpha_data)


# Extraction utility function
def _write_file(path, data):
    """Write ``data`` to ``path`` in binary mode and always close the file."""
    out = open(path, 'wb')
    try:
        out.write(data)
    finally:
        out.close()


def dump_jpegs(fn):
    """Extract every JPEG bitmap from the SWF file ``fn``.

    A summary of the movie header is printed, then the tag stream is walked
    until the end tag (code 0).  Bitmaps from tags 6 and 21 are written as
    ``<id>.jpg`` and bitmaps from tag 35 (JPEG plus alpha channel) as
    ``<id>.png``, all in the directory that contains ``fn``.  Tag 8 supplies
    the shared JPEG table used by later tag 6 bitmaps.  Files whose signature
    is neither ``CWS`` nor ``FWS`` are ignored.  Returns ``None``.
    """
    pn = os.path.split(fn)[0]
    fp = open(fn, 'rb')
    try:
        sig, ver, length = struct.unpack('<3sBL', fp.read(8))
        if sig == 'CWS':
            # Everything after the 8-byte header is zlib-compressed.
            bs = BitStream(StringIO.StringIO(zlib.decompress(fp.read())))
        elif sig == 'FWS':
            bs = BitStream(fp)
        else:
            return

        rect, fps, nframes = read_movie_header(bs)
        print('sig:        ' + sig)
        print('version:    %d' % ver)
        print('length:     %d' % length)
        print('screen:     %.1fx%.1f' % (rect[1]-rect[0], rect[3]-rect[2]))
        print('fps:        %.1f' % fps)
        print('num frames: %d' % nframes)

        table = None
        while True:
            # Tag header: upper 10 bits are the tag code, lower 6 the length;
            # a length of 63 means the real length follows as 32 bits.
            code = bs.read_bytes(2, False)
            tag = code >> 6
            tag_len = code & 0x3f
            if tag_len == 63:
                tag_len = bs.read_bytes(4, False)
            # Process JPEG tags, or skip
            if tag == 0:
                break
            elif tag == 8:
                table = read_jpeg_table(bs, tag_len)
            elif tag == 6:
                if not table:
                    # Without a table the image data cannot be completed
                    # into a JPEG (and None + str would raise), so skip it.
                    print('skipping JPEG bits tag: no JPEG table available')
                    skip_bytes(bs, tag_len)
                else:
                    bitmap_id, bits = read_jpeg_bits(bs, tag_len, table)
                    _write_file(os.path.join(pn, '%d.jpg' % bitmap_id), bits)
            elif tag == 21:
                bitmap_id, bits = read_jpeg_bits_2(bs, tag_len)
                _write_file(os.path.join(pn, '%d.jpg' % bitmap_id), bits)
            elif tag == 35:
                bitmap_id, img_bits, alpha_bits = read_jpeg_bits_3(bs, tag_len)
                img = Image.open(StringIO.StringIO(img_bits)).convert('RGBA')
                img.putalpha(Image.fromstring('L', img.size, alpha_bits))
                img.save(os.path.join(pn, '%d.png' % bitmap_id))
            else:
                skip_bytes(bs, tag_len)
    finally:
        # Close the input on every exit path, including the early return
        # and any parse error.
        fp.close()

