# deep-tempest/gr-tempest/python/buttonToFileSink.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2023 Gabriel Varela, Emilio Martínez.
#
# SPDX-License-Identifier: GPL-3.0-or-later
#
### Classic imports
import numpy as np
from gnuradio import gr
import pmt
from scipy import signal
from datetime import datetime
from PIL import Image
### Deep-TEMPEST imports
import os
import argparse
import torch
import sys
from .DTutils import apply_blanking_shift, remove_outliers, adjust_dynamic_range
from . import utils_option as option
from . import utils_image as util
from .utils_dist import get_dist_info, init_dist
from .select_model import define_Model
from . import basicblock as B
from .network_unet import UNetRes as net
def load_enhancement_model(json_path=None):
    """
    Build and load the pretrained enhancement network (UNetRes).

    Parameters
    ----------
    json_path : str or None
        Path to the option JSON file describing the netG architecture and
        the pretrained weights location (``opt['path']['pretrained_netG']``).

    Returns
    -------
    torch.nn.Module
        The network in eval mode, gradients disabled, moved to CUDA when
        available and to CPU otherwise.
    """
    # ----------------------------------------
    # Step 1 - Prepare options
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.')
    parser.add_argument('--launcher', default='pytorch', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--dist', default=False)
    # parse_known_args so that flags added by GNU Radio / the host process do
    # not abort the flowgraph; parse once and reuse (the original parsed twice).
    args, _ = parser.parse_known_args()
    opt = option.parse(args.opt, is_train=True)
    opt['dist'] = args.dist

    # ----------------------------------------
    # Step 2 - Distributed settings
    # ----------------------------------------
    if opt['dist']:
        init_dist('pytorch')
    opt['rank'], opt['world_size'] = get_dist_info()

    opt = option.dict_to_nonedict(opt)

    # ----------------------------------------
    # Step 3 - Build network and load weights
    # ----------------------------------------
    model_path = opt['path']['pretrained_netG']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    opt_netG = opt['netG']
    model = net(in_nc=opt_netG['in_nc'],
                out_nc=opt_netG['out_nc'],
                nc=opt_netG['nc'],
                nb=opt_netG['nb'],
                act_mode=opt_netG['act_mode'],
                bias=opt_netG['bias'])
    # map_location lets CUDA-trained checkpoints load on CPU-only hosts,
    # matching the device fallback chosen above.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.eval()
    # Inference only: freeze all parameters.
    for _, param in model.named_parameters():
        param.requires_grad = False
    return model.to(device)
class buttonToFileSink(gr.sync_block):
    """
    Sink that saves a screenshot of ``input_width * V_size`` complex samples
    after receiving a True boolean message on the 'en' message port.

    The captured stream is resampled to the full ``H_size x V_size`` frame,
    mapped to an RGB image (real part -> red, imaginary part -> green, blue
    left at zero), optionally blanking-corrected, optionally enhanced with a
    pretrained deep-learning model, and written to disk as PNG.
    """

    def __init__(self, Filename="output.png", input_width=740, H_size=2200, V_size=1125,
                 remove_blanking=False, enhance_image=False, option_path=None):
        gr.sync_block.__init__(self,
                               name="buttonToFileSink",
                               in_sig=[np.complex64],
                               out_sig=[])
        self.Filename = Filename
        self.input_width = input_width
        self.H_size = H_size                        # total horizontal size (incl. blanking)
        self.V_size = V_size                        # total vertical size (incl. blanking)
        self.enhance_image = enhance_image
        self.option_path = option_path
        self.remove_blanking = remove_blanking
        self.num_samples = int(input_width * V_size)  # samples per capture
        self.en = False                             # capture disabled until message arrives
        self.remaining2Save = 0
        self.savingSamples = 0
        self.message_port_register_in(pmt.intern("en"))       # declare message port
        self.set_msg_handler(pmt.intern("en"), self.handle_msg)  # handler for messages
        self.stream_image = []                      # accumulator for incoming samples
        # TODO: fancy active-blanking resolution identification.
        # Map common total resolutions to their active areas; unknown totals
        # yield 0, matching the original boolean-arithmetic behavior.
        self.V_active = {1125: 1080, 1000: 900, 750: 720, 628: 600, 525: 480}.get(self.V_size, 0)
        self.H_active = {2200: 1920, 1800: 1600, 1650: 1280, 1056: 800, 800: 640}.get(self.H_size, 0)
        self.V_blanking = self.V_size - self.V_active
        self.H_blanking = self.H_size - self.H_active
        if self.enhance_image:
            # Load the deep-learning enhancement model once, up front.
            self.model = load_enhancement_model(self.option_path)

    def work(self, input_items, output_items):
        """Accumulate samples while enabled; trigger a save once a full frame is buffered."""
        # Don't process, just count the available samples.
        self.available_samples = len(input_items[0])
        if self.en:
            self.stream_image.extend(input_items[0])
            # Keep only the newest num_samples (one full capture).
            self.stream_image = self.stream_image[-self.num_samples:]
            if len(self.stream_image) == self.num_samples:
                # Save the buffered capture.
                self.save_samples()
                # Back to default: disarm and empty the buffer for the next screenshot.
                self.en = False
                self.stream_image = []
        # Consume everything at the input, saved or not.
        return self.available_samples

    def save_samples(self):
        """Resample the buffered capture to the full frame and write PNG file(s)."""
        # Interpolate signal to the original (total) frame size.
        interpolated_signal = signal.resample(self.stream_image, self.H_size * self.V_size)
        # Reshape signal to image.
        captured_image_complex = np.array(interpolated_signal).reshape((self.V_size, self.H_size))
        # Real part -> red channel, imaginary part -> green channel; blue stays 0.
        captured_image = np.zeros((self.V_size, self.H_size, 3))
        captured_image[:, :, 0] = np.real(captured_image_complex)
        captured_image[:, :, 1] = np.imag(captured_image_complex)
        # Stretch contrast jointly over both channels so the complex phase is unchanged.
        min_value = np.min(captured_image[:, :, :2])
        max_value = np.max(captured_image[:, :, :2])
        captured_image[:, :, :2] = 255 * (captured_image[:, :, :2] - min_value) / (max_value - min_value)
        captured_image = captured_image.astype('uint8')
        # Timestamp for the output filename.
        # NOTE(review): ':' is invalid in Windows filenames; fine on Linux targets.
        date_time = datetime.now().strftime("%d-%m-%Y_%H:%M:%S")
        if self.remove_blanking:
            # Fix shift using blanking redundancy information.
            captured_image = apply_blanking_shift(captured_image,
                                                  h_active=self.H_active, v_active=self.V_active,
                                                  h_blanking=self.H_blanking, v_blanking=self.V_blanking)
            # Remove outliers with the median-thresholding heuristic.
            img_L = remove_outliers(captured_image)
            # Stretch dynamic range to [0, 255].
            captured_image = adjust_dynamic_range(img_L)
        if self.enhance_image:
            #######################################################################
            ### Preprocess image and run inference with the deep-learning model ###
            #######################################################################
            img_L = remove_outliers(captured_image)
            img_L = adjust_dynamic_range(img_L)
            # Model consumes the two (real, imag) channels only.
            img_L = img_L[:, :, :2]
            # uint8 -> float tensor.
            img_L = util.uint2single(img_L)
            img_L = util.single2tensor4(img_L)
            # Bug fix: move the input to the model's device — the tensor was
            # previously left on CPU, which fails when the model is on CUDA.
            device = next(self.model.parameters()).device
            img_E = self.model(img_L.to(device))
            capture_enhanced = util.tensor2uint(img_E)
            # Save enhanced image as PNG.
            im = Image.fromarray(capture_enhanced)
            im.save(self.Filename + '-gr-tempest_screenshot_enhanced_' + date_time + '.png')
            # Side-by-side comparison: raw capture (grayscale) vs enhanced.
            height, width = captured_image.shape[:2]
            # Bug fix: Image.fromarray rejects 2-D float64 arrays; buffer must be uint8.
            imgshow = np.zeros((height, 2 * width), dtype=np.uint8)
            imgshow[:, :width] = np.mean(captured_image, axis=2).astype('uint8')
            imgshow[:, width:] = capture_enhanced
            # Show comparison at runtime.
            Image.fromarray(imgshow).show()
        # Save the complex-capture image as PNG.
        im_complex = Image.fromarray(captured_image)
        im_complex.save(self.Filename + '-gr-tempest_screenshot_' + date_time + '.png')
        if not self.enhance_image:
            # Show image at runtime (comparison already shown when enhancing).
            im_complex.show()

    def handle_msg(self, msg):
        """Message handler for the 'en' port: arm/disarm the capture.

        The button block sends a (name, value) pair; only the cdr (value)
        carries the boolean we care about.
        """
        msg_value = pmt.cdr(msg)
        self.en = pmt.to_bool(msg_value)