-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dev(narugo): add nai metadata support
- Loading branch information
1 parent
34af200
commit 577cbce
Showing
11 changed files
with
385 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .extract import LSBExtractor, ImageLsbDataExtractor | ||
from .inject import serialize_metadata, inject_data | ||
from .metadata import get_naimeta_from_image, NAIMetadata, add_naimeta_to_image, save_image_with_naimeta |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import gzip | ||
import json | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta.py | ||
class LSBExtractor(object): | ||
def __init__(self, data: np.ndarray): | ||
self.data = data | ||
self.rows, self.cols, self.dim = data.shape | ||
self.bits = 0 | ||
self.byte = 0 | ||
self.row = 0 | ||
self.col = 0 | ||
|
||
def _extract_next_bit(self): | ||
if self.row < self.rows and self.col < self.cols: | ||
bit = self.data[self.row, self.col, self.dim - 1] & 1 | ||
self.bits += 1 | ||
self.byte <<= 1 | ||
self.byte |= bit | ||
self.row += 1 | ||
if self.row == self.rows: | ||
self.row = 0 | ||
self.col += 1 | ||
|
||
def get_one_byte(self): | ||
while self.bits < 8: | ||
self._extract_next_bit() | ||
byte = bytearray([self.byte]) | ||
self.bits = 0 | ||
self.byte = 0 | ||
return byte | ||
|
||
def get_next_n_bytes(self, n): | ||
bytes_list = bytearray() | ||
for _ in range(n): | ||
byte = self.get_one_byte() | ||
if not byte: | ||
break | ||
bytes_list.extend(byte) | ||
return bytes_list | ||
|
||
def read_32bit_integer(self): | ||
bytes_list = self.get_next_n_bytes(4) | ||
if len(bytes_list) == 4: | ||
integer_value = int.from_bytes(bytes_list, byteorder='big') | ||
return integer_value | ||
else: | ||
return None | ||
|
||
|
||
# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta.py | ||
class ImageLsbDataExtractor(object): | ||
def __init__(self, magic: str = "stealth_pngcomp"): | ||
self._magic_bytes = magic.encode('utf-8') | ||
|
||
def extract_data(self, image: Image.Image) -> dict: | ||
if image.mode != 'RGBA': | ||
raise ValueError(f'Image should be in RGBA mode, but {image.mode!r} found.') | ||
image = np.array(image) | ||
reader = LSBExtractor(image) | ||
|
||
read_magic = reader.get_next_n_bytes(len(self._magic_bytes)) | ||
if not (self._magic_bytes == read_magic): | ||
raise ValueError(f'Image magic number mismatch, ' | ||
f'{self._magic_bytes!r} expected but {read_magic!r}.') | ||
|
||
read_len = reader.read_32bit_integer() // 8 | ||
json_data = reader.get_next_n_bytes(read_len) | ||
|
||
json_data = json.loads(gzip.decompress(json_data).decode("utf-8")) | ||
return json_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# MIT: https://github.com/NovelAI/novelai-image-metadata/blob/main/nai_meta_writer.py | ||
import gzip | ||
import json | ||
|
||
# BCH error correction | ||
import bchlib | ||
import numpy as np | ||
from PIL import Image | ||
from PIL.PngImagePlugin import PngInfo | ||
|
||
correctable_bits = 16 | ||
block_length = 2019 | ||
code_block_len = 1920 | ||
|
||
|
||
def bit_shuffle(data_bytes, w, h, use_bytes=False): | ||
bits = np.frombuffer(data_bytes, dtype=np.uint8) | ||
bit_fac = 8 | ||
if use_bytes: | ||
bit_fac = 1 | ||
else: | ||
bits = np.unpackbits(bits) | ||
bits = bits.reshape((h, w, 3 * bit_fac)) | ||
code_block_len = 1920 | ||
flat_tile_len = (w * h * 3) // code_block_len | ||
tile_w = 32 | ||
if flat_tile_len // tile_w > 100: | ||
tile_w = 64 | ||
tile_h = flat_tile_len // tile_w | ||
h_cutoff = (h // tile_h) * tile_h | ||
tile_hr = h - h_cutoff | ||
easy_tiles = bits[:h_cutoff].reshape(h_cutoff // tile_h, tile_h, w // tile_w, tile_w, 3 * bit_fac) | ||
easy_tiles = easy_tiles.swapaxes(1, 2) | ||
easy_tiles = easy_tiles.reshape(-1, tile_h * tile_w) | ||
easy_tiles = easy_tiles.T | ||
rest_tiles = bits[h_cutoff:] | ||
rest_tiles = rest_tiles.reshape(tile_hr, 1, w // tile_w, tile_w, 3 * bit_fac) | ||
rest_tiles = rest_tiles.swapaxes(1, 2) | ||
rest_tiles = rest_tiles.reshape(-1, tile_hr * tile_w) | ||
rest_tiles = rest_tiles.T | ||
rest_dim = rest_tiles.shape[-1] | ||
rest_tiles = np.pad(rest_tiles, ((0, 0), (0, easy_tiles.shape[-1] - rest_tiles.shape[-1])), mode='constant', | ||
constant_values=0) | ||
bits = np.concatenate((easy_tiles, rest_tiles), axis=0) | ||
dim = bits.shape[-1] | ||
bits = bits.reshape((-1,)) | ||
if not use_bytes: | ||
bits = np.packbits(bits) | ||
return bytearray(bits.tobytes()), dim, rest_tiles.shape[0], rest_dim | ||
|
||
|
||
def split_byte_ranges(data_bytes, n, w, h): | ||
data_bytes, dim, rest_size, rest_dim = bit_shuffle(data_bytes.copy(), w, h, use_bytes=True) | ||
chunks = [] | ||
for i in range(0, len(data_bytes), n): | ||
chunks.append(data_bytes[i:i + n]) | ||
return chunks, dim, rest_size, rest_dim | ||
|
||
|
||
def pad(data_bytes): | ||
return bytearray(data_bytes + b'\x00' * (2019 - len(data_bytes))) | ||
|
||
|
||
# Returns codes for the data in data_bytes | ||
def fec_encode(data_bytes, w, h): | ||
encoder = bchlib.BCH(16, prim_poly=17475) | ||
# import galois | ||
# encoder = galois.BCH(16383, 16383-224, d=17, c=224) | ||
chunks = [bytearray(encoder.encode(pad(x))) for x in split_byte_ranges(data_bytes, 2019, w, h)[0]] | ||
return b''.join(chunks) | ||
|
||
|
||
class LSBInjector: | ||
def __init__(self, data): | ||
self.data = data | ||
self.buffer = bytearray() | ||
|
||
def put_byte(self, byte): | ||
self.buffer.append(byte) | ||
|
||
def put_32bit_integer(self, integer_value): | ||
self.buffer.extend(integer_value.to_bytes(4, byteorder='big')) | ||
|
||
def put_bytes(self, bytes_list): | ||
self.buffer.extend(bytes_list) | ||
|
||
def put_string(self, string): | ||
self.put_bytes(string.encode('utf-8')) | ||
|
||
def finalize(self): | ||
buffer = np.frombuffer(self.buffer, dtype=np.uint8) | ||
buffer = np.unpackbits(buffer) | ||
data = self.data[..., -1].T | ||
h, w = data.shape | ||
data = data.reshape((-1,)) | ||
data[:] = 0xff | ||
buf_len = buffer.shape[0] | ||
data[:buf_len] = 0xfe | ||
data[:buf_len] = np.bitwise_or(data[:buf_len], buffer) | ||
data = data.reshape((h, w)).T | ||
self.data[..., -1] = data | ||
|
||
|
||
def serialize_metadata(metadata: PngInfo) -> bytes: | ||
# Extract metadata from PNG chunks | ||
data = { | ||
k: v | ||
for k, v in [ | ||
data[1] | ||
.decode("latin-1" if data[0] == b"tEXt" else "utf-8") | ||
.split("\x00" if data[0] == b"tEXt" else "\x00\x00\x00\x00\x00") | ||
for data in metadata.chunks | ||
if data[0] == b"tEXt" or data[0] == b"iTXt" | ||
] | ||
} | ||
# Save space by getting rid of reduntant metadata (Title is static) | ||
if "Title" in data: | ||
del data["Title"] | ||
# Encode and compress data using gzip | ||
data_encoded = json.dumps(data) | ||
return gzip.compress(bytes(data_encoded, "utf-8")) | ||
|
||
|
||
def inject_data(image: Image.Image, data: PngInfo) -> Image.Image: | ||
rgb = np.array(image.convert('RGB')) | ||
image = image.convert('RGBA') | ||
w, h = image.size | ||
pixels = np.array(image) | ||
injector = LSBInjector(pixels) | ||
injector.put_string("stealth_pngcomp") | ||
data = serialize_metadata(data) | ||
injector.put_32bit_integer(len(data) * 8) | ||
injector.put_bytes(data) | ||
fec_data = fec_encode(bytearray(rgb.tobytes()), w, h) | ||
injector.put_32bit_integer(len(fec_data) * 8) | ||
injector.put_bytes(fec_data) | ||
injector.finalize() | ||
return Image.fromarray(injector.data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import json | ||
import os | ||
from dataclasses import dataclass | ||
from typing import Optional, Union | ||
|
||
from PIL import Image | ||
from PIL.PngImagePlugin import PngInfo | ||
|
||
from .extract import ImageLsbDataExtractor | ||
from .inject import inject_data | ||
from ...data import load_image, ImageTyping | ||
|
||
|
||
@dataclass | ||
class NAIMetadata: | ||
software: str | ||
source: str | ||
title: Optional[str] = None | ||
generation_time: Optional[float] = None | ||
description: Optional[str] = None | ||
parameters: Optional[dict] = None | ||
|
||
@property | ||
def pnginfo(self) -> PngInfo: | ||
info = PngInfo() | ||
info.add_text('Software', self.software) | ||
info.add_text('Source', self.source) | ||
if self.title is not None: | ||
info.add_text('Title', self.title) | ||
if self.generation_time is not None: | ||
info.add_text('Generation time', json.dumps(self.generation_time)), | ||
if self.description is not None: | ||
info.add_text('Description', self.description) | ||
if self.parameters is not None: | ||
info.add_text('Comment', json.dumps(self.parameters)) | ||
return info | ||
|
||
|
||
def _get_naimeta_raw(image: ImageTyping) -> dict: | ||
image = load_image(image, force_background=None, mode=None) | ||
try: | ||
return ImageLsbDataExtractor().extract_data(image) | ||
except (ValueError, json.JSONDecodeError): | ||
return image.info or {} | ||
|
||
|
||
def get_naimeta_from_image(image: ImageTyping) -> Optional[NAIMetadata]: | ||
data = _get_naimeta_raw(image) | ||
if data.get('Software') and data.get('Source'): | ||
return NAIMetadata( | ||
software=data['Software'], | ||
source=data['Source'], | ||
title=data.get('Title'), | ||
generation_time=float(data['Generation time']) if data.get('Generation time') else None, | ||
description=data.get('Description'), | ||
parameters=json.loads(data['Comment']) if data.get('Comment') else None, | ||
) | ||
else: | ||
return None | ||
|
||
|
||
def _get_pnginfo(metadata: Union[NAIMetadata, PngInfo]) -> PngInfo: | ||
if isinstance(metadata, NAIMetadata): | ||
pnginfo = metadata.pnginfo | ||
elif isinstance(metadata, PngInfo): | ||
pnginfo = metadata | ||
else: | ||
raise TypeError(f'Unknown metadata type for NAI - {metadata!r}.') | ||
return pnginfo | ||
|
||
|
||
def add_naimeta_to_image(image: ImageTyping, metadata: Union[NAIMetadata, PngInfo]) -> Image.Image: | ||
pnginfo = _get_pnginfo(metadata) | ||
image = load_image(image, mode=None, force_background=None) | ||
return inject_data(image, data=pnginfo) | ||
|
||
|
||
def save_image_with_naimeta(image: ImageTyping, dst_file: Union[str, os.PathLike], | ||
metadata: Union[NAIMetadata, PngInfo], | ||
add_lsb_meta: bool = True, save_pnginfo: bool = True, **kwargs) -> Image.Image: | ||
pnginfo = _get_pnginfo(metadata) | ||
image = load_image(image, mode=None, force_background=None) | ||
if add_lsb_meta: | ||
image = add_naimeta_to_image(image, metadata=pnginfo) | ||
if save_pnginfo: | ||
kwargs['pnginfo'] = pnginfo | ||
image.save(dst_file, **kwargs) | ||
return image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,4 +14,5 @@ shapely | |
pyclipper | ||
deprecation>=2.0.0 | ||
hfutils>=0.2.3 | ||
filelock | ||
filelock | ||
bchlib |
Oops, something went wrong.