Misc

d3image

Challenge

我一定是训练模型训练出了幻觉,怎么从这张图里看出了“不存在”的文字?

Solution

为了还原mysterious_invitation.png中隐藏的信息,我们需要实现your_decode_net函数。根据编码过程,d3net是一个可逆网络,它将DWT变换后的封面图像和秘密信息合并后进行转换。因此,your_decode_net实际上是d3net的逆过程。

具体步骤如下:

  1. 实现 INV_block 的逆操作 INV_block_reverse:INV_block 是 d3net 的基本组成单元。我们需要根据其前向传播的数学关系,推导出对应的逆运算以恢复原始输入。
  2. 实现 D3net 的逆操作 D3net_reverse:D3net 由多个 INV_block 串联组成。其逆操作就是将 INV_block_reverse 按相反的顺序串联起来。
  3. 在 decode 函数中使用 D3net_reverse:
    • 将待解码的图片进行DWT变换。
    • 构建D3net_reverse的输入。由于d3net的前向传播是(cover_dwt, payload_dwt) -> (stego_dwt, z_channels),那么其逆向传播就是(stego_dwt, z_prior) -> (recovered_cover_dwt, recovered_payload_dwt)。这里的z_prior通常是一个全零张量,表示编码时被压缩或推向零的隐变量。
    • 运行D3net_reverse以获得恢复的秘密信息DWT。
    • 对恢复的秘密信息DWT应用IWT,还原为原始的位图表示。
    • 最后,将位图转换为文本信息。

下面是修改后的文件内容:

block.py:

python
import torch
import torch.nn as nn
from utils import initialize_weights


class ResidualDenseBlock_out(nn.Module):
    """Densely connected stack of five 3x3 convolutions (12 -> 12 channels).

    Each of the first four convolutions sees the block input concatenated
    with every earlier activation and emits ``hidden_size`` channels through
    a LeakyReLU. The fifth convolution maps the full concatenation back to
    ``channel`` channels without an activation, and its weights are scaled
    by 0 via ``initialize_weights``.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.channel = 12
        self.hidden_size = 32

        base, growth = self.channel, self.hidden_size
        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(base, growth, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(base + growth, growth, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(base + 2 * growth, growth, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(base + 3 * growth, growth, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(base + 4 * growth, base, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(inplace=True)
        # initialization
        initialize_weights([self.conv5], 0.)

    def forward(self, x):
        """Return the 12-channel output of the dense block for input ``x``."""
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        return self.conv5(torch.cat(features, 1))


class INV_block(nn.Module):
    """Affine coupling block over a 24-channel tensor split into two halves.

    Forward mapping for input halves ``(x1, x2)``::

        y1 = x1 + f(x2)
        y2 = e(r(y1)) * x2 + y(y1)
    """

    def __init__(self, clamp=2.0):
        super().__init__()
        self.channels = 3
        self.clamp = clamp
        self.r = ResidualDenseBlock_out()  # rho: scale sub-network
        self.y = ResidualDenseBlock_out()  # eta: shift sub-network
        self.f = ResidualDenseBlock_out()  # phi: additive sub-network

    def e(self, s):
        """Return the bounded multiplicative scale exp(clamp * 2 * (sigmoid(s) - 0.5))."""
        return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5))

    def forward(self, x):
        """Apply the coupling to ``x`` and return ``cat((y1, y2), dim=1)``."""
        half = self.channels * 4
        x1 = x.narrow(1, 0, half)
        x2 = x.narrow(1, half, half)

        y1 = x1 + self.f(x2)
        scale = self.e(self.r(y1))
        shift = self.y(y1)
        y2 = scale * x2 + shift

        return torch.cat((y1, y2), 1)


# Added for inverse operation
class INV_block_reverse(nn.Module):
    """Inverse of an ``INV_block``, sharing that block's sub-networks.

    The constructor stores references to the ``r``, ``y`` and ``f`` modules of
    the given block (no copies), so the inverse uses the same weights.
    """

    def __init__(self, inv_block_instance):
        super().__init__()
        source = inv_block_instance
        # Reuse the original sub-modules so the same trained weights apply.
        self.r = source.r
        self.y = source.y
        self.f = source.f

        self.channels = source.channels
        self.clamp = source.clamp

    def e(self, s):
        """Return the bounded multiplicative scale exp(clamp * 2 * (sigmoid(s) - 0.5))."""
        return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5))

    def forward(self, y_cat):
        """Recover ``cat((x1, x2), dim=1)`` from ``y_cat = cat((y1, y2), dim=1)``.

        Inverts the forward relations in reverse order::

            x2 = (y2 - y(y1)) / e(r(y1))
            x1 = y1 - f(x2)
        """
        half = self.channels * 4
        y1 = y_cat.narrow(1, 0, half)
        y2 = y_cat.narrow(1, half, half)

        # Scale and shift depend only on y1, which is available directly.
        scale_logits = self.r(y1)
        shift = self.y(y1)
        x2 = (y2 - shift) / self.e(scale_logits)

        # With x2 known, undo the additive step.
        x1 = y1 - self.f(x2)

        return torch.cat((x1, x2), 1)

utils.py:

python
import math
import zlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from reedsolo import RSCodec

# Reed-Solomon codec shared by the text <-> payload helpers (128 ECC symbols).
rs = RSCodec(128)


def initialize_weights(net_l, scale=1):
    """Kaiming-initialise Conv2d/Linear weights (times ``scale``) and reset norms.

    Args:
        net_l: A module or a list of modules; every sub-module is visited.
        scale: Factor applied to Conv2d/Linear weights after initialisation.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


class IWT(nn.Module):
    """Inverse of ``DWT``: rebuilds (B, C/4, 2H, 2W) from a (B, C, H, W) input."""

    def __init__(self):
        super(IWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        """Return the float32 reconstruction, allocated on ``x``'s device."""
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        out_batch = in_batch
        out_channel = in_channel // (r ** 2)
        out_height, out_width = r * in_height, r * in_width

        x1 = x[:, 0:out_channel, :, :] / 2
        x2 = x[:, out_channel:out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2

        # Allocate next to the input instead of forcing CUDA, so the
        # transform also works on CPU-only hosts.
        h = torch.zeros(
            [out_batch, out_channel, out_height, out_width],
            dtype=torch.float32,
            device=x.device,
        )

        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

        return h


class DWT(nn.Module):
    """2x2 sub-band transform: (B, C, H, W) -> (B, 4C, H/2, W/2)."""

    def __init__(self):
        super(DWT, self).__init__()
        self.requires_grad = False

    def forward(self, x):
        """Return ``cat((LL, HL, LH, HH), dim=1)`` computed from 2x2 pixel groups."""
        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2
        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]
        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4
        return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)


def random_data(cover, device):
    """Return a tensor shaped like ``cover`` filled with random 0/1 values."""
    return torch.zeros(cover.size(), device=device).random_(0, 2)


def auxiliary_variable(shape, device=None):
    """Return standard-normal noise of ``shape``, sampled one batch entry at a time.

    Args:
        shape: Target tensor shape; the first dimension is iterated over.
        device: Where to place the result. Defaults to CUDA when available
            (the previous hard-coded behaviour), otherwise CPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    noise = torch.zeros(shape, device=device)
    for i in range(noise.shape[0]):
        # Sample on CPU, then move, so the RNG stream matches the original code.
        noise[i] = torch.randn(noise[i].shape).to(device)

    return noise


def computePSNR(origin, pred):
    """Return the PSNR (dB, peak 255) between two images; 100 if MSE < 1e-10."""
    origin = np.array(origin).astype(np.float32)
    pred = np.array(pred).astype(np.float32)
    mse = np.mean((origin / 1.0 - pred / 1.0) ** 2)
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(255.0 ** 2 / mse)


def make_payload(width, height, depth, text, batch=1):
    """Tile the encoded ``text`` into a (batch, depth, height, width) float tensor.

    The bit message is the encoded text followed by 32 zero bits; it is
    repeated as often as needed and truncated to exactly fill the tensor.
    """
    message = text_to_bits(text) + [0] * 32
    total = batch * width * height * depth

    # Ceil-divide to get enough copies, then trim. This avoids the original
    # in-place ``+=`` on a list that aliased ``message``.
    repeats = -(-total // len(message))
    payload = (message * repeats)[:total]
    return torch.FloatTensor(payload).view(batch, depth, height, width)


def text_to_bits(text):
    """Return ``text`` as a list of 0/1 ints (compressed and RS-encoded)."""
    return bytearray_to_bits(text_to_bytearray(text))


def bytearray_to_bits(x):
    """Expand each byte of ``x`` into eight 0/1 ints, most significant bit first."""
    result = []
    for value in x:
        result.extend(int(b) for b in format(value, '08b'))

    return result


def text_to_bytearray(text):
    """zlib-compress the UTF-8 bytes of ``text`` and add Reed-Solomon parity."""
    assert isinstance(text, str), "expected a string"
    x = zlib.compress(text.encode("utf-8"))
    x = rs.encode(bytearray(x))

    return x


def bits_to_bytearray(bits):
    """Pack a sequence of 0/1 (or bool) values into bytes; trailing bits are dropped."""
    # ``0 + array`` turns booleans into integers before stringifying them.
    bits = (0 + np.array(bits)).tolist()
    ints = []
    for b in range(len(bits) // 8):
        byte = bits[b * 8:(b + 1) * 8]
        ints.append(int(''.join(str(bit) for bit in byte), 2))
    return bytearray(ints)


def bytearray_to_text(x):
    """Decode RS-protected, zlib-compressed bytes to text; return False on failure.

    Failure is expected for most candidates, so decoding errors are
    swallowed. Only ``Exception`` is caught, so KeyboardInterrupt and
    SystemExit still propagate.
    """
    try:
        text = rs.decode(x)
        text = zlib.decompress(text[0])
        return text.decode("utf-8")
    except Exception:
        return False

d3net.py:

python
# Import torch directly instead of ``from model import *``: model.py imports
# this module, so the star import made the two modules depend on each other
# and only worked when ``model`` happened to be imported first.
import torch
import torch.nn as nn

from block import INV_block, INV_block_reverse  # Import INV_block_reverse


class D3net(nn.Module):
    """Eight ``INV_block`` coupling layers applied in sequence (inv1 .. inv8)."""

    def __init__(self):
        super(D3net, self).__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.inv1 = INV_block()
        self.inv2 = INV_block()
        self.inv3 = INV_block()
        self.inv4 = INV_block()
        self.inv5 = INV_block()
        self.inv6 = INV_block()
        self.inv7 = INV_block()
        self.inv8 = INV_block()

    def forward(self, x):
        """Run ``x`` through inv1 .. inv8 and return the final block's output."""
        out = x
        for block in (self.inv1, self.inv2, self.inv3, self.inv4,
                      self.inv5, self.inv6, self.inv7, self.inv8):
            out = block(out)
        return out


# Added for inverse operation
class D3net_reverse(nn.Module):
    """Inverse of a ``D3net``: the reversed blocks applied from inv8 down to inv1.

    Each ``INV_block_reverse`` wraps the corresponding block of
    ``original_d3net_instance``.
    """

    def __init__(self, original_d3net_instance):
        super().__init__()
        self.inv_blocks_rev = nn.ModuleList()
        # The forward network applies inv1 .. inv8, so undo them last-to-first.
        for index in range(8, 0, -1):
            original_inv_block = getattr(original_d3net_instance, f'inv{index}')
            self.inv_blocks_rev.append(INV_block_reverse(original_inv_block))

    def forward(self, y_cat):
        """Map a forward output ``y_cat`` back to the forward network's input.

        ``y_cat`` has the layout produced by ``D3net.forward``; the result has
        the layout ``D3net.forward`` consumes.
        """
        out = y_cat
        for inv_block_rev in self.inv_blocks_rev:
            out = inv_block_rev(out)
        return out

model.py:

python
# Import order matters if d3net pulls names from this module while it is
# still being imported: keep the torch imports ahead of the d3net import.
import torch.nn as nn
import torch
from d3net import D3net


class Model(nn.Module):
    """Wrapper that owns a ``D3net`` under the ``model`` attribute.

    Args:
        cuda: When true, the wrapped network is moved to the default CUDA device.
    """

    def __init__(self, cuda=True):
        super(Model, self).__init__()
        self.model = D3net()
        if cuda:
            self.model.cuda()
        # init_model(self)  # left disabled so it cannot disturb pretrained weights

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)


def init_model(mod):
    """Re-initialise every trainable parameter of ``mod`` in place.

    Parameters are drawn from N(0, 0.01^2); those whose owning module is named
    ``conv5`` are then set to zero. Each new tensor is placed on the device
    the parameter already lives on, rather than being forced onto CUDA, so
    this also works for ``Model(cuda=False)``.
    """
    for key, param in mod.named_parameters():
        if not param.requires_grad:
            continue
        # Sample on CPU and move afterwards to keep the RNG stream unchanged.
        param.data = 0.01 * torch.randn(param.data.shape).to(param.device)
        parts = key.split('.')
        # Guard against dot-less names (parameters registered directly on ``mod``).
        if len(parts) >= 2 and parts[-2] == 'conv5':
            param.data.fill_(0.)

test.py:

python
# Keep ``model`` ahead of ``d3net`` in the import list: the two modules may
# import each other, and this is the order that is known to resolve.
import torch
from model import Model
from utils import DWT, IWT, make_payload, auxiliary_variable, bits_to_bytearray, bytearray_to_text
import torchvision
from collections import Counter
from PIL import Image
import torchvision.transforms as T

# Import the reverse D3net
from d3net import D3net_reverse


transform_test = T.Compose([
    T.CenterCrop((720, 1280)),
    T.ToTensor(),
])


def load(name):
    """Load the ``'net'`` weights of checkpoint ``name`` into the global ``d3net``.

    Entries whose key contains ``'tmp_var'`` are skipped. Tensors are mapped
    onto the global ``device`` so a checkpoint saved on GPU also loads on a
    CPU-only host.
    """
    # NOTE: torch.load unpickles the file; only open checkpoints you trust.
    state_dicts = torch.load(name, map_location=device)
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    d3net.load_state_dict(network_state_dict)


def transform2tensor(img):
    """Open image path ``img`` as RGB and return a (1, 3, 720, 1280) tensor on ``device``."""
    img = Image.open(img)
    img = img.convert('RGB')
    return transform_test(img).unsqueeze(0).to(device)


def encode(cover, text):
    """Hide ``text`` in the image at path ``cover`` and write ``./steg.png``."""
    cover = transform2tensor(cover)
    B, C, H, W = cover.size()

    payload = make_payload(W, H, C, text, B)
    payload = payload.to(device)
    cover_input = dwt(cover)
    payload_input = dwt(payload)

    input_img = torch.cat([cover_input, payload_input], dim=1)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = d3net(input_img)

        # The first 12 channels of the output carry the stego image sub-bands.
        output_steg = output.narrow(1, 0, 4 * 3)
        output_img = iwt(output_steg)
    # torchvision.utils.save_image(cover, f'./{text}.png')
    torchvision.utils.save_image(output_img, f'./steg.png')


def decode(steg_path):
    """Recover and print the message hidden in the image at ``steg_path``.

    Raises:
        ValueError: If no candidate chunk decodes to valid text.
    """
    steg_tensor = transform2tensor(steg_path)
    stego_dwt = dwt(steg_tensor)  # y1: 12 channels, (B, 12, H/2, W/2)

    B, C, H, W = stego_dwt.size()

    # The second half of the forward output (y2) is not stored in the image;
    # this solution substitutes an all-zero tensor for it.
    z_prior = torch.zeros(B, C, H, W).to(device)

    # 12 stego channels + 12 prior channels = the 24-channel layout that the
    # forward network produces.
    input_to_reverse = torch.cat((stego_dwt, z_prior), 1)

    # ``d3net`` is a ``Model``; ``d3net.model`` is the D3net holding the weights.
    your_decode_net_instance = D3net_reverse(d3net.model)
    your_decode_net_instance.eval()
    your_decode_net_instance.to(device)

    # Inference only: without no_grad the whole inverse pass would keep
    # activations alive for a backward pass that never happens.
    with torch.no_grad():
        # Output layout mirrors the forward input: (cover_dwt, payload_dwt).
        recovered_channels = your_decode_net_instance(input_to_reverse)

        # Channels 12..23 are the recovered payload sub-bands.
        secret_dwt = recovered_channels.narrow(1, 4 * 3, 4 * 3)

        # Back to the 3-channel bitmap representation of the payload.
        secret_rev = iwt(secret_dwt)

    # Threshold at zero to turn the recovered values into bits.
    image = secret_rev.view(-1) > 0

    candidates = Counter()
    bits = image.int().cpu().numpy().tolist()

    # make_payload appends 32 zero bits (four 0x00 bytes) after every copy of
    # the RS-encoded message, so splitting on four zero bytes separates the
    # repeated copies; the most frequent successfully decoded text wins.
    for candidate in bits_to_bytearray(bits).split(b'\x00\x00\x00\x00'):
        candidate = bytearray_to_text(bytearray(candidate))
        if candidate:
            candidates[candidate] += 1
    if len(candidates) == 0:
        raise ValueError('Failed to find message.')
    candidate, count = candidates.most_common(1)[0]
    print(candidate)


if __name__ == '__main__':
    # Pick the device first so model placement and checkpoint loading agree
    # with it (the original built the model with cuda=True unconditionally).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    d3net = Model(cuda=device.type == 'cuda')
    load('magic.potions')
    d3net.eval()

    dwt = DWT()
    iwt = IWT()

    text = r'd3ctf{Getting that model to converge felt like pure sorcery}'
    steg = r'./steg.png'
    cover = './poster.png'
    # encode(cover, text)  # Disabled so the provided stego image is not overwritten.
    decode(steg)  # Call decode with the stego image.