Misc
d3image
Challenge
我一定是训练模型训练出了幻觉,怎么从这张图里看出了“不存在”的文字?
Solution
为了还原mysterious_invitation.png中隐藏的信息,我们需要实现your_decode_net函数。根据编码过程,d3net是一个可逆网络,它将DWT变换后的封面图像和秘密信息合并后进行转换。因此,your_decode_net实际上是d3net的逆过程。
具体步骤如下:
- 实现
INV_block的逆操作INV_block_reverse:INV_block是d3net的基本组成单元。我们需要根据其前向传播的数学关系,推导出反向传播以恢复原始输入。 - 实现
D3net的逆操作D3net_reverse:D3net由多个INV_block串联组成。其逆操作就是将INV_block_reverse按相反的顺序串联起来。 - 在
decode函数中使用D3net_reverse:- 将待解码的图片进行DWT变换。
- 构建
D3net_reverse的输入。由于d3net的前向传播是(cover_dwt, payload_dwt) -> (stego_dwt, z_channels),那么其逆向传播就是(stego_dwt, z_prior) -> (recovered_cover_dwt, recovered_payload_dwt)。这里的z_prior通常是一个全零张量,表示编码时被压缩或推向零的隐变量。 - 运行
D3net_reverse以获得恢复的秘密信息DWT。 - 对恢复的秘密信息DWT应用IWT,还原为原始的位图表示。
- 最后,将位图转换为文本信息。
下面是修改后的文件内容:
block.py:
python
import torchimport torch.nn as nnfrom utils import initialize_weights # Dense connectionclass ResidualDenseBlock_out(nn.Module): def __init__(self, bias=True): super(ResidualDenseBlock_out, self).__init__() self.channel = 12 self.hidden_size = 32 self.conv1 = nn.Conv2d(self.channel, self.hidden_size, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(self.channel + self.hidden_size, self.hidden_size, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(self.channel + 2 * self.hidden_size, self.hidden_size, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(self.channel + 3 * self.hidden_size, self.hidden_size, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(self.channel + 4 * self.hidden_size, self.channel, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(inplace=True) # initialization initialize_weights([self.conv5], 0.) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 class INV_block(nn.Module): def __init__(self, clamp=2.0): super().__init__() self.channels = 3 self.clamp = clamp # ρ self.r = ResidualDenseBlock_out() # η self.y = ResidualDenseBlock_out() # φ self.f = ResidualDenseBlock_out() def e(self, s): return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5)) def forward(self, x): x1, x2 = (x.narrow(1, 0, self.channels*4), x.narrow(1, self.channels*4, self.channels*4)) t2 = self.f(x2) y1 = x1 + t2 s1, t1 = self.r(y1), self.y(y1) y2 = self.e(s1) * x2 + t1 return torch.cat((y1, y2), 1) # Added for inverse operationclass INV_block_reverse(nn.Module): def __init__(self, inv_block_instance): super().__init__() # Store references to the original block's sub-modules # This is critical to use the SAME trained weights self.r = inv_block_instance.r self.y = inv_block_instance.y self.f = inv_block_instance.f self.channels = inv_block_instance.channels self.clamp = 
inv_block_instance.clamp def e(self, s): return torch.exp(self.clamp * 2 * (torch.sigmoid(s) - 0.5)) def forward(self, y_cat): # y_cat is torch.cat((y1, y2), 1) y1, y2 = (y_cat.narrow(1, 0, self.channels*4), y_cat.narrow(1, self.channels*4, self.channels*4)) # Inverse operations based on INV_block.forward: # Original: # t2 = self.f(x2) # y1 = x1 + t2 => x1 = y1 - t2 # s1, t1 = self.r(y1), self.y(y1) # y2 = self.e(s1) * x2 + t1 => x2 = (y2 - t1) / self.e(s1) # Reversing order: # 1. Calculate s1 and t1 using y1 s1 = self.r(y1) t1 = self.y(y1) # 2. Calculate x2 using y2, t1, and s1 e_s1 = self.e(s1) x2 = (y2 - t1) / e_s1 # 3. Calculate t2 using x2 t2 = self.f(x2) # 4. Calculate x1 using y1 and t2 x1 = y1 - t2 return torch.cat((x1, x2), 1) utils.py:
python
import torch.nn as nnimport torch.nn.init as initimport torchimport numpy as npimport mathfrom reedsolo import RSCodecimport zlib rs = RSCodec(128) def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode='fan_in') m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) class IWT(nn.Module): def __init__(self): super(IWT, self).__init__() self.requires_grad = False def forward(self, x): r = 2 in_batch, in_channel, in_height, in_width = x.size() #print([in_batch, in_channel, in_height, in_width]) out_batch, out_channel, out_height, out_width = in_batch, int( in_channel / (r ** 2)), r * in_height, r * in_width x1 = x[:, 0:out_channel, :, :] / 2 x2 = x[:, out_channel:out_channel * 2, :, :] / 2 x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2 x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2 h = torch.zeros([out_batch, out_channel, out_height, out_width]).float().cuda() h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4 h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4 h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4 h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4 return hclass DWT(nn.Module): def __init__(self): super(DWT, self).__init__() self.requires_grad = False def forward(self, x): x01 = x[:, :, 0::2, :] / 2 x02 = x[:, :, 1::2, :] / 2 x1 = x01[:, :, :, 0::2] x2 = x02[:, :, :, 0::2] x3 = x01[:, :, :, 1::2] x4 = x02[:, :, :, 1::2] x_LL = x1 + x2 + x3 + x4 x_HL = -x1 - x2 + x3 + x4 x_LH = -x1 + x2 - x3 + x4 x_HH = x1 - x2 - x3 + x4 return torch.cat((x_LL, x_HL, x_LH, x_HH), 1) def random_data(cover,device): return torch.zeros(cover.size(), device=device).random_(0, 2) def 
auxiliary_variable(shape): noise = torch.zeros(shape).cuda() for i in range(noise.shape[0]): noise[i] = torch.randn(noise[i].shape).cuda() return noise def computePSNR(origin,pred): origin = np.array(origin) origin = origin.astype(np.float32) pred = np.array(pred) pred = pred.astype(np.float32) mse = np.mean((origin/1.0 - pred/1.0) ** 2 ) if mse < 1.0e-10: return 100 return 10 * math.log10(255.0**2/mse) def make_payload(width, height, depth, text, batch = 1): message = text_to_bits(text) + [0] * 32 payload = message while len(payload) < batch * width * height * depth: payload += message payload = payload[:batch * width * height * depth] return torch.FloatTensor(payload).view(batch, depth, height, width) def text_to_bits(text): return bytearray_to_bits(text_to_bytearray(text)) def bytearray_to_bits(x): result = [] for i in x: bits = bin(i)[2:] bits = '00000000'[len(bits):] + bits result.extend([int(b) for b in bits]) return result def text_to_bytearray(text): assert isinstance(text, str), "expected a string" x = zlib.compress(text.encode("utf-8")) x = rs.encode(bytearray(x)) return x def bits_to_bytearray(bits): ints = [] bits = np.array(bits) bits = 0 + bits bits = bits = bits.tolist() for b in range(len(bits) // 8): byte = bits[b * 8:(b + 1) * 8] ints.append(int(''.join([str(bit) for bit in byte]), 2)) return bytearray(ints) def bytearray_to_text(x): try: text = rs.decode(x) text = zlib.decompress(text[0]) return text.decode("utf-8") except BaseException: return Falsed3net.py:
python
from model import *from block import INV_block, INV_block_reverse # Import INV_block_reverse class D3net(nn.Module): def __init__(self): super(D3net, self).__init__() self.inv1 = INV_block() self.inv2 = INV_block() self.inv3 = INV_block() self.inv4 = INV_block() self.inv5 = INV_block() self.inv6 = INV_block() self.inv7 = INV_block() self.inv8 = INV_block() def forward(self, x): out = self.inv1(x) out = self.inv2(out) out = self.inv3(out) out = self.inv4(out) out = self.inv5(out) out = self.inv6(out) out = self.inv7(out) out = self.inv8(out) return out # Added for inverse operationclass D3net_reverse(nn.Module): def __init__(self, original_d3net_instance): super().__init__() self.inv_blocks_rev = nn.ModuleList() # Iterate through original blocks in reverse order # The original D3net has inv1 to inv8. So, index from 7 down to 0. for i in range(7, -1, -1): # From inv8 down to inv1 original_inv_block = getattr(original_d3net_instance, f'inv{i+1}') self.inv_blocks_rev.append(INV_block_reverse(original_inv_block)) def forward(self, y_cat): # y_cat is the output of the forward pass of original D3net # which is (stego_dwt, z_channels) out = y_cat for inv_block_rev in self.inv_blocks_rev: out = inv_block_rev(out) # The final 'out' should be (recovered_cover_dwt, recovered_payload_dwt) return outmodel.py:
python
import torch.nn as nnimport torchfrom d3net import D3net class Model(nn.Module): def __init__(self,cuda=True): super(Model, self).__init__() self.model = D3net() if cuda: self.model.cuda() # init_model(self) # This is commented out, so it won't affect loading pretrained weights def forward(self, x): out = self.model(x) return out def init_model(mod): for key, param in mod.named_parameters(): split = key.split('.') if param.requires_grad: param.data = 0.01 * torch.randn(param.data.shape).cuda() if split[-2] == 'conv5': param.data.fill_(0.)test.py:
python
import torch
from model import Model
from utils import DWT, IWT, make_payload, auxiliary_variable, bits_to_bytearray, bytearray_to_text
import torchvision
from collections import Counter
from PIL import Image
import torchvision.transforms as T

# Import the reverse D3net
from d3net import D3net_reverse

transform_test = T.Compose([
    T.CenterCrop((720, 1280)),
    T.ToTensor(),
])


def load(name):
    """Load checkpoint `name` into the global `d3net`, dropping temporary variables."""
    state_dicts = torch.load(name)
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    d3net.load_state_dict(network_state_dict)


def transform2tensor(img):
    """Open an image path, force RGB, center-crop to 720x1280, return a (1, 3, 720, 1280) tensor on `device`."""
    img = Image.open(img)
    img = img.convert('RGB')
    return transform_test(img).unsqueeze(0).to(device)


def encode(cover, text):
    """Hide `text` inside `cover` and save the stego image as ./steg.png."""
    cover = transform2tensor(cover)
    B, C, H, W = cover.size()
    payload = make_payload(W, H, C, text, B)
    payload = payload.to(device)
    cover_input = dwt(cover)
    payload_input = dwt(payload)
    input_img = torch.cat([cover_input, payload_input], dim=1)
    output = d3net(input_img)
    output_steg = output.narrow(1, 0, 4 * 3)  # first 12 channels = stego DWT
    output_img = iwt(output_steg)
    # torchvision.utils.save_image(cover, f'./{text}.png')
    torchvision.utils.save_image(output_img, f'./steg.png')


def decode(steg_path):
    """Recover and print the hidden text from the stego image at `steg_path`.

    Raises ValueError if no RS-decodable message candidate is found.
    """
    steg_tensor = transform2tensor(steg_path)
    stego_dwt = dwt(steg_tensor)  # (B, 12, H/2, W/2): the y1 half of D3net's output
    B, C, H, W = stego_dwt.size()

    # The forward net's second output half (z_channels) is trained toward a
    # simple prior; a zero tensor is the standard choice of y2 when inverting.
    z_prior = torch.zeros(B, C, H, W).to(device)

    # 24-channel input (y1, y2) matching the forward D3net's output layout.
    input_to_reverse = torch.cat((stego_dwt, z_prior), 1)

    # Build the inverse network around the trained D3net (d3net is the Model
    # wrapper; d3net.model holds the weights).
    your_decode_net_instance = D3net_reverse(d3net.model)
    your_decode_net_instance.eval()
    your_decode_net_instance.to(device)

    # BUGFIX: run inference under no_grad — the original built an autograd
    # graph across all 8 inverse blocks for a full-resolution image, wasting
    # a large amount of memory for identical numerical output.
    with torch.no_grad():
        recovered_channels = your_decode_net_instance(input_to_reverse)

    # Forward input was (cover_dwt, payload_dwt); channels 12..23 are the payload.
    secret_dwt = recovered_channels.narrow(1, 4 * 3, 4 * 3)
    secret_rev = iwt(secret_dwt)

    # Threshold at 0 to recover the bitstream, then split on the 4-zero-byte
    # delimiter that make_payload appended and try RS+zlib decoding each piece.
    image = secret_rev.view(-1) > 0
    candidates = Counter()
    bits = image.data.int().cpu().numpy().tolist()
    for candidate in bits_to_bytearray(bits).split(b'\x00\x00\x00\x00'):
        candidate = bytearray_to_text(bytearray(candidate))
        if candidate:
            candidates[candidate] += 1
    if len(candidates) == 0:
        raise ValueError('Failed to find message.')
    candidate, count = candidates.most_common(1)[0]
    print(candidate)


if __name__ == '__main__':
    d3net = Model()
    load('magic.potions')
    d3net.eval()
    dwt = DWT()
    iwt = IWT()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    text = r'd3ctf{Getting that model to converge felt like pure sorcery}'
    steg = r'./steg.png'
    cover = './poster.png'
    # encode(cover, text)  # commented out to prevent re-encoding
    decode(steg)