Skip to content

Commit

Permalink
dev(narugo): add nai metadata support
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Sep 7, 2024
1 parent 34af200 commit 577cbce
Show file tree
Hide file tree
Showing 11 changed files with 385 additions and 1 deletion.
1 change: 1 addition & 0 deletions imgutils/sd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
"""
from .metadata import parse_sdmeta_from_text, get_sdmeta_from_image, SDMetaData
from .model import read_metadata, save_with_metadata
from .nai import get_naimeta_from_image, NAIMetadata
3 changes: 3 additions & 0 deletions imgutils/sd/nai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .extract import LSBExtractor, ImageLsbDataExtractor
from .inject import serialize_metadata, inject_data
from .metadata import get_naimeta_from_image, NAIMetadata, add_naimeta_to_image, save_image_with_naimeta
75 changes: 75 additions & 0 deletions imgutils/sd/nai/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import gzip
import json

import numpy as np
from PIL import Image


# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta.py
class LSBExtractor(object):
def __init__(self, data: np.ndarray):
self.data = data
self.rows, self.cols, self.dim = data.shape
self.bits = 0
self.byte = 0
self.row = 0
self.col = 0

def _extract_next_bit(self):
if self.row < self.rows and self.col < self.cols:
bit = self.data[self.row, self.col, self.dim - 1] & 1
self.bits += 1
self.byte <<= 1
self.byte |= bit
self.row += 1
if self.row == self.rows:
self.row = 0
self.col += 1

def get_one_byte(self):
while self.bits < 8:
self._extract_next_bit()
byte = bytearray([self.byte])
self.bits = 0
self.byte = 0
return byte

def get_next_n_bytes(self, n):
bytes_list = bytearray()
for _ in range(n):
byte = self.get_one_byte()
if not byte:
break
bytes_list.extend(byte)
return bytes_list

def read_32bit_integer(self):
bytes_list = self.get_next_n_bytes(4)
if len(bytes_list) == 4:
integer_value = int.from_bytes(bytes_list, byteorder='big')
return integer_value
else:
return None


# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta.py
class ImageLsbDataExtractor(object):
def __init__(self, magic: str = "stealth_pngcomp"):
self._magic_bytes = magic.encode('utf-8')

def extract_data(self, image: Image.Image) -> dict:
if image.mode != 'RGBA':
raise ValueError(f'Image should be in RGBA mode, but {image.mode!r} found.')
image = np.array(image)
reader = LSBExtractor(image)

read_magic = reader.get_next_n_bytes(len(self._magic_bytes))
if not (self._magic_bytes == read_magic):
raise ValueError(f'Image magic number mismatch, '
f'{self._magic_bytes!r} expected but {read_magic!r}.')

read_len = reader.read_32bit_integer() // 8
json_data = reader.get_next_n_bytes(read_len)

json_data = json.loads(gzip.decompress(json_data).decode("utf-8"))
return json_data
138 changes: 138 additions & 0 deletions imgutils/sd/nai/inject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta_writer.py
import gzip
import json

# BCH error correction
import bchlib
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo

correctable_bits = 16
block_length = 2019
code_block_len = 1920


def bit_shuffle(data_bytes, w, h, use_bytes=False):
bits = np.frombuffer(data_bytes, dtype=np.uint8)
bit_fac = 8
if use_bytes:
bit_fac = 1
else:
bits = np.unpackbits(bits)
bits = bits.reshape((h, w, 3 * bit_fac))
code_block_len = 1920
flat_tile_len = (w * h * 3) // code_block_len
tile_w = 32
if flat_tile_len // tile_w > 100:
tile_w = 64
tile_h = flat_tile_len // tile_w
h_cutoff = (h // tile_h) * tile_h
tile_hr = h - h_cutoff
easy_tiles = bits[:h_cutoff].reshape(h_cutoff // tile_h, tile_h, w // tile_w, tile_w, 3 * bit_fac)
easy_tiles = easy_tiles.swapaxes(1, 2)
easy_tiles = easy_tiles.reshape(-1, tile_h * tile_w)
easy_tiles = easy_tiles.T
rest_tiles = bits[h_cutoff:]
rest_tiles = rest_tiles.reshape(tile_hr, 1, w // tile_w, tile_w, 3 * bit_fac)
rest_tiles = rest_tiles.swapaxes(1, 2)
rest_tiles = rest_tiles.reshape(-1, tile_hr * tile_w)
rest_tiles = rest_tiles.T
rest_dim = rest_tiles.shape[-1]
rest_tiles = np.pad(rest_tiles, ((0, 0), (0, easy_tiles.shape[-1] - rest_tiles.shape[-1])), mode='constant',
constant_values=0)
bits = np.concatenate((easy_tiles, rest_tiles), axis=0)
dim = bits.shape[-1]
bits = bits.reshape((-1,))
if not use_bytes:
bits = np.packbits(bits)
return bytearray(bits.tobytes()), dim, rest_tiles.shape[0], rest_dim


def split_byte_ranges(data_bytes, n, w, h):
data_bytes, dim, rest_size, rest_dim = bit_shuffle(data_bytes.copy(), w, h, use_bytes=True)
chunks = []
for i in range(0, len(data_bytes), n):
chunks.append(data_bytes[i:i + n])
return chunks, dim, rest_size, rest_dim


def pad(data_bytes):
return bytearray(data_bytes + b'\x00' * (2019 - len(data_bytes)))


# Returns codes for the data in data_bytes
def fec_encode(data_bytes, w, h):
encoder = bchlib.BCH(16, prim_poly=17475)
# import galois
# encoder = galois.BCH(16383, 16383-224, d=17, c=224)
chunks = [bytearray(encoder.encode(pad(x))) for x in split_byte_ranges(data_bytes, 2019, w, h)[0]]
return b''.join(chunks)


class LSBInjector:
def __init__(self, data):
self.data = data
self.buffer = bytearray()

def put_byte(self, byte):
self.buffer.append(byte)

def put_32bit_integer(self, integer_value):
self.buffer.extend(integer_value.to_bytes(4, byteorder='big'))

def put_bytes(self, bytes_list):
self.buffer.extend(bytes_list)

def put_string(self, string):
self.put_bytes(string.encode('utf-8'))

def finalize(self):
buffer = np.frombuffer(self.buffer, dtype=np.uint8)
buffer = np.unpackbits(buffer)
data = self.data[..., -1].T
h, w = data.shape
data = data.reshape((-1,))
data[:] = 0xff
buf_len = buffer.shape[0]
data[:buf_len] = 0xfe
data[:buf_len] = np.bitwise_or(data[:buf_len], buffer)
data = data.reshape((h, w)).T
self.data[..., -1] = data


def serialize_metadata(metadata: PngInfo) -> bytes:
# Extract metadata from PNG chunks
data = {
k: v
for k, v in [
data[1]
.decode("latin-1" if data[0] == b"tEXt" else "utf-8")
.split("\x00" if data[0] == b"tEXt" else "\x00\x00\x00\x00\x00")
for data in metadata.chunks
if data[0] == b"tEXt" or data[0] == b"iTXt"
]
}
# Save space by getting rid of reduntant metadata (Title is static)
if "Title" in data:
del data["Title"]
# Encode and compress data using gzip
data_encoded = json.dumps(data)
return gzip.compress(bytes(data_encoded, "utf-8"))


def inject_data(image: Image.Image, data: PngInfo) -> Image.Image:
rgb = np.array(image.convert('RGB'))
image = image.convert('RGBA')
w, h = image.size
pixels = np.array(image)
injector = LSBInjector(pixels)
injector.put_string("stealth_pngcomp")
data = serialize_metadata(data)
injector.put_32bit_integer(len(data) * 8)
injector.put_bytes(data)
fec_data = fec_encode(bytearray(rgb.tobytes()), w, h)
injector.put_32bit_integer(len(fec_data) * 8)
injector.put_bytes(fec_data)
injector.finalize()
return Image.fromarray(injector.data)
88 changes: 88 additions & 0 deletions imgutils/sd/nai/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import os
from dataclasses import dataclass
from typing import Optional, Union

from PIL import Image
from PIL.PngImagePlugin import PngInfo

from .extract import ImageLsbDataExtractor
from .inject import inject_data
from ...data import load_image, ImageTyping


@dataclass
class NAIMetadata:
software: str
source: str
title: Optional[str] = None
generation_time: Optional[float] = None
description: Optional[str] = None
parameters: Optional[dict] = None

@property
def pnginfo(self) -> PngInfo:
info = PngInfo()
info.add_text('Software', self.software)
info.add_text('Source', self.source)
if self.title is not None:
info.add_text('Title', self.title)
if self.generation_time is not None:
info.add_text('Generation time', json.dumps(self.generation_time)),
if self.description is not None:
info.add_text('Description', self.description)
if self.parameters is not None:
info.add_text('Comment', json.dumps(self.parameters))
return info


def _get_naimeta_raw(image: ImageTyping) -> dict:
image = load_image(image, force_background=None, mode=None)
try:
return ImageLsbDataExtractor().extract_data(image)
except (ValueError, json.JSONDecodeError):
return image.info or {}


def get_naimeta_from_image(image: ImageTyping) -> Optional[NAIMetadata]:
data = _get_naimeta_raw(image)
if data.get('Software') and data.get('Source'):
return NAIMetadata(
software=data['Software'],
source=data['Source'],
title=data.get('Title'),
generation_time=float(data['Generation time']) if data.get('Generation time') else None,
description=data.get('Description'),
parameters=json.loads(data['Comment']) if data.get('Comment') else None,
)
else:
return None


def _get_pnginfo(metadata: Union[NAIMetadata, PngInfo]) -> PngInfo:
if isinstance(metadata, NAIMetadata):
pnginfo = metadata.pnginfo
elif isinstance(metadata, PngInfo):
pnginfo = metadata
else:
raise TypeError(f'Unknown metadata type for NAI - {metadata!r}.')
return pnginfo


def add_naimeta_to_image(image: ImageTyping, metadata: Union[NAIMetadata, PngInfo]) -> Image.Image:
pnginfo = _get_pnginfo(metadata)
image = load_image(image, mode=None, force_background=None)
return inject_data(image, data=pnginfo)


def save_image_with_naimeta(image: ImageTyping, dst_file: Union[str, os.PathLike],
metadata: Union[NAIMetadata, PngInfo],
add_lsb_meta: bool = True, save_pnginfo: bool = True, **kwargs) -> Image.Image:
pnginfo = _get_pnginfo(metadata)
image = load_image(image, mode=None, force_background=None)
if add_lsb_meta:
image = add_naimeta_to_image(image, metadata=pnginfo)
if save_pnginfo:
kwargs['pnginfo'] = pnginfo
image.save(dst_file, **kwargs)
return image
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ shapely
pyclipper
deprecation>=2.0.0
hfutils>=0.2.3
filelock
filelock
bchlib
Loading

0 comments on commit 577cbce

Please sign in to comment.