222 lines
8.2 KiB
Python
222 lines
8.2 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
#
|
|
# Copyright 2023 Gabriel Varela, Emilio Martínez.
|
|
#
|
|
# SPDX-License-Identifier: GPL-3.0-or-later
|
|
#
|
|
|
|
### Classic imports
|
|
import numpy as np
|
|
from gnuradio import gr
|
|
import pmt
|
|
from scipy import signal
|
|
from datetime import datetime
|
|
from PIL import Image
|
|
|
|
### Deep-TEMPEST imports
|
|
import os
|
|
import argparse
|
|
import torch
|
|
import sys
|
|
|
|
from .DTutils import apply_blanking_shift, remove_outliers, adjust_dynamic_range
|
|
|
|
from . import utils_option as option
|
|
from . import utils_image as util
|
|
from .utils_dist import get_dist_info, init_dist
|
|
from .select_model import define_Model
|
|
from . import basicblock as B
|
|
from .network_unet import UNetRes as net
|
|
|
|
def load_enhancement_model(json_path=None):
    """Build the Deep-TEMPEST UNetRes enhancement network and load its weights.

    Options are read with ``option.parse`` from the JSON file given by
    ``json_path``; a ``--opt`` flag on the process command line overrides it.
    Network hyper-parameters come from ``opt['netG']`` and the weight file
    from ``opt['path']['pretrained_netG']``.

    Args:
        json_path: Path to the option JSON file (default value of ``--opt``).

    Returns:
        The ``UNetRes`` model with its state dict loaded (strict), in eval mode,
        with ``requires_grad`` disabled on every parameter, placed on CUDA when
        available and on the CPU otherwise.
    """
    # ----------------------------------------
    # Step 1 - Prepare options
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    parser.add_argument('--launcher', default='pytorch', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    # NOTE(review): no type/action, so "--dist False" yields the truthy string "False".
    parser.add_argument('--dist', default=False)

    # This runs inside a GNU Radio flowgraph whose command line may carry flags
    # this parser does not know; parse_known_args() ignores them instead of
    # calling sys.exit() as parse_args() would. Parse only once.
    args, _unknown = parser.parse_known_args()

    opt = option.parse(args.opt, is_train=True)
    opt['dist'] = args.dist

    # ----------------------------------------
    # Step 2 - distributed settings
    # ----------------------------------------
    if opt['dist']:
        init_dist('pytorch')
    opt['rank'], opt['world_size'] = get_dist_info()

    # presumably makes missing keys read as None — see utils_option
    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # Step 3 - Load model with option setup
    # ----------------------------------------
    model_path = opt['path']['pretrained_netG']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    opt_netG = opt['netG']
    model = net(in_nc=opt_netG['in_nc'], out_nc=opt_netG['out_nc'], nc=opt_netG['nc'],
                nb=opt_netG['nb'], act_mode=opt_netG['act_mode'], bias=opt_netG['bias'])

    # map_location lets weights saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    # Inference only: freeze all parameters.
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    return model
|
|
|
|
|
|
class buttonToFileSink(gr.sync_block):
    """
    Block that captures num_samples complex samples after receiving a TRUE boolean
    message on the 'en' port and saves them as a PNG screenshot.

    The capture is resampled to a V_size x H_size frame whose real and imaginary
    parts go to the first two color channels. Optionally the blanking shift is
    corrected (remove_blanking) and an enhanced version is produced with the
    Deep-TEMPEST model configured by option_path (enhance_image).
    """

    # Total (active + blanking) size -> active size for the supported video modes.
    # Any other size maps to 0, i.e. no active region is known.
    _V_ACTIVE_BY_TOTAL = {1125: 1080, 1000: 900, 750: 720, 628: 600, 525: 480}
    _H_ACTIVE_BY_TOTAL = {2200: 1920, 1800: 1600, 1650: 1280, 1056: 800, 800: 640}

    def __init__(self, Filename="output.png", input_width=740, H_size=2200, V_size=1125,
                 remove_blanking=False, enhance_image=False, option_path=None):
        gr.sync_block.__init__(
            self,
            name="buttonToFileSink",
            in_sig=[np.complex64],
            out_sig=[],
        )
        self.Filename = Filename
        self.input_width = input_width
        self.H_size = H_size
        self.V_size = V_size
        self.enhance_image = enhance_image
        self.option_path = option_path
        self.remove_blanking = remove_blanking
        # Samples in one captured frame: input_width per line times V_size lines.
        self.num_samples = int(input_width * V_size)
        self.en = False  # capture disabled until a TRUE message arrives on 'en'
        # Kept for compatibility; not read anywhere in this block.
        self.remaining2Save = 0
        self.savingSamples = 0
        self.message_port_register_in(pmt.intern("en"))  # declare message port
        self.set_msg_handler(pmt.intern("en"), self.handle_msg)  # declare handler
        # Samples of the frame being captured, kept as one complex64 array
        # (a Python list of numpy scalars is far slower and larger).
        self.stream_image = np.empty(0, dtype=np.complex64)

        # TODO: fancy active-blanking resolution identification
        self.V_active = self._V_ACTIVE_BY_TOTAL.get(self.V_size, 0)
        self.H_active = self._H_ACTIVE_BY_TOTAL.get(self.H_size, 0)

        self.V_blanking = self.V_size - self.V_active
        self.H_blanking = self.H_size - self.H_active

        if self.enhance_image:
            # Load model
            self.model = load_enhancement_model(self.option_path)

    def work(self, input_items, output_items):
        """Buffer input while capture is enabled and save once a full frame is held.

        Every input sample is consumed, whether it was buffered or not.
        """
        self.available_samples = len(input_items[0])

        if self.en:
            # Append this chunk, keeping only the most recent num_samples samples.
            self.stream_image = np.concatenate((self.stream_image, input_items[0]))[-self.num_samples:]

            if len(self.stream_image) == self.num_samples:
                self.save_samples()
                # Back to default and empty the buffer for the next screenshot.
                self.en = False
                self.stream_image = np.empty(0, dtype=np.complex64)

        return self.available_samples

    def save_samples(self):
        """Convert the buffered frame to PNG screenshot(s), save them and show them."""
        captured_image = self._samples_to_image(self.stream_image)

        # Date and time of screenshot
        # NOTE(review): ':' is not a valid file-name character on Windows.
        date_time = datetime.now().strftime("%d-%m-%Y_%H:%M:%S")

        if self.remove_blanking:
            # Fix shift with blanking redundance information
            captured_image = apply_blanking_shift(captured_image, h_active=self.H_active, v_active=self.V_active,
                                                  h_blanking=self.H_blanking, v_blanking=self.V_blanking)
            # Remove outliers with median thresholding heuristic, then stretch to [0,255]
            captured_image = adjust_dynamic_range(remove_outliers(captured_image))

        if self.enhance_image:
            capture_enhanced = self._enhance(captured_image)

            # Save enhanced image as png
            Image.fromarray(capture_enhanced).save(
                self.Filename + '-gr-tempest_screenshot_enhanced_' + date_time + '.png')

            # Captured image (channel mean) vs enhanced image, side by side, as uint8
            height, width = captured_image.shape[:2]
            imgshow = np.zeros((height, 2 * width), dtype=np.uint8)
            imgshow[:, :width] = np.mean(captured_image, axis=2).astype('uint8')
            imgshow[:, width:] = capture_enhanced
            # Show images at runtime
            Image.fromarray(imgshow).show()

        # Save complex capture image as png
        im_complex = Image.fromarray(captured_image)
        im_complex.save(self.Filename + '-gr-tempest_screenshot_' + date_time + '.png')
        if not self.enhance_image:
            # Show image at runtime
            im_complex.show()

    def _samples_to_image(self, samples):
        """Resample complex samples to V_size x H_size and encode them as a uint8 RGB image.

        Real part goes to channel 0, imaginary part to channel 1, channel 2 stays 0.
        Both parts share a single min/max stretch to [0, 255].
        """
        # Interpolate signal to original image size, then reshape to the frame.
        interpolated = signal.resample(samples, self.H_size * self.V_size)
        complex_frame = np.asarray(interpolated).reshape((self.V_size, self.H_size))

        image = np.zeros((self.V_size, self.H_size, 3))
        image[:, :, 0] = np.real(complex_frame)
        image[:, :, 1] = np.imag(complex_frame)

        re_im = image[:, :, :2]
        min_value, max_value = np.min(re_im), np.max(re_im)
        span = max_value - min_value
        if span > 0:
            image[:, :, :2] = 255 * (re_im - min_value) / span
        else:
            # Constant frame (e.g. all zeros): avoid 0/0 -> NaN before the uint8 cast.
            image[:, :, :2] = 0

        return image.astype('uint8')

    def _enhance(self, captured_image):
        """Run the Deep-TEMPEST model on the first two channels and return a uint8 image."""
        # Remove outliers with median thresholding heuristic, stretch to [0,255]
        img_L = adjust_dynamic_range(remove_outliers(captured_image))
        img_L = img_L[:, :, :2]
        # uint8 -> float -> 4D tensor via the project's image utilities
        img_L = util.single2tensor4(util.uint2single(img_L))

        # The model may live on CUDA (see load_enhancement_model); feed it on the
        # same device and skip autograd bookkeeping for pure inference.
        device = next(self.model.parameters()).device
        with torch.no_grad():
            img_E = self.model(img_L.to(device))

        # Bring the result back to host memory before converting to numpy.
        return util.tensor2uint(img_E.cpu())

    # Handler of msg
    def handle_msg(self, msg):
        """Set the capture flag from a message on the 'en' port.

        The button block sends a (name, value) pair and only the value is used;
        a bare boolean PMT is accepted as well.
        """
        value = pmt.cdr(msg) if pmt.is_pair(msg) else msg
        self.en = pmt.to_bool(value)
|