Spaces:
Sleeping
Sleeping
| ##################################################### | |
| # AST Composite Server Double Two | |
| # By Guillaume Descoteaux-Isabelle, 2021 | |
| # | |
| # This server composes two Adaptive Style Transfer passes with the same model (the output of the first pass serves as input to the second); the second pass is currently commented out, so stylize() runs a single pass. | |
| ######################################################## | |
| #v1-dev | |
| #Receives the pass-1 resolution (x1) and the auto-contrast clip percentage (c1) from the request arguments. | |
| import os | |
| import numpy as np | |
| import tensorflow as tf | |
| import cv2 | |
| from module import encoder, decoder | |
| from glob import glob | |
| import runway | |
| from runway.data_types import number, text | |
| #from utils import * | |
| import scipy | |
| from datetime import datetime | |
| import time | |
| import re | |
SRV_TYPE = "s1"
# Give each RW_* Runway setting a default unless the environment already
# provides a non-empty value for it.
for _rw_name, _rw_default in (
    ("RW_PORT", "7860"),
    ("RW_DEBUG", "0"),
    ("RW_HOST", "0.0.0.0"),
    ("RW_MODEL_OPTIONS", '{"styleCheckpoint":"/data/styleCheckpoint"}'),
):
    if not os.getenv(_rw_name):
        os.environ[_rw_name] = _rw_default
del _rw_name, _rw_default
# Determining the size of the passes
def _int_from_env(name, default):
    """Return environment variable *name* parsed as an int.

    *default* is returned when the variable is unset or empty, or (with a
    warning) when its value is not an integer.
    """
    raw_value = os.getenv(name)
    if not raw_value:
        return default
    try:
        return int(raw_value)
    except ValueError:
        print("WARNING:" + name + " is not an integer (" + repr(raw_value) + ");using default:" + str(default))
        return default


# Default pass-1 resolution. Env values are converted to int (os.getenv returns
# str) so the variable has the same type whether or not it is configured.
# NOTE: stylize() sizes pass 1 from the request's 'x1' input, not from this value.
pass1_image_size = 1328
if not os.getenv('PASS1IMAGESIZE'):
    print("PASS1IMAGESIZE env var non existent;using default:" + str(pass1_image_size))
else:
    pass1_image_size = _int_from_env('PASS1IMAGESIZE', pass1_image_size)
    print("PASS1IMAGESIZE value:" + str(pass1_image_size))

# Automatic brightness/contrast flag (1 = enabled, 0 = disabled), also as int.
# NOTE(review): autoabc / abcdefault are not read elsewhere in this file.
autoabc = 1
if not os.getenv('AUTOABC'):
    print("AUTOABC env var non existent;using default:")
    print(autoabc)
    abcdefault = 1
    print("NOTE----> when running docker, set AUTOABC variable")
    print(" docker run ... -e AUTOABC=1 #enabled (default), 0 to disable")
else:
    autoabc = _int_from_env('AUTOABC', autoabc)
    print("AUTOABC value:")
    print(autoabc)
    abcdefault = autoabc
| #pass2_image_size = 1024 | |
| #if not os.getenv('PASS2IMAGESIZE'): | |
| # print("PASS2IMAGESIZE env var non existent;using default:" + pass2_image_size) | |
| #else: | |
| # pass2_image_size = os.getenv('PASS2IMAGESIZE') | |
| # print("PASS2IMAGESIZE value:" + pass2_image_size) | |
| # pass3_image_size = 2048 | |
| # if not os.getenv('PASS3IMAGESIZE'): | |
| # print("PASS3IMAGESIZE env var non existent;using default:" + pass3_image_size) | |
| # else: | |
| # pass3_image_size = os.getenv('PASS3IMAGESIZE') | |
| # print("PASS3IMAGESIZE value:" + pass3_image_size) | |
| ########################################## | |
| ## MODELS | |
| #model name for sending it in the response | |
# Name reported as "model1name" in stylize() responses; "UNNAMED" makes
# stylize() fall back to the model found by setup().
if os.getenv('MODEL1NAME'):
    model1name = os.getenv('MODEL1NAME')
    print("MODEL1NAME value:" + model1name)
else:
    model1name = "UNNAMED"
    print("MODEL1NAME env var non existent;using default:" + model1name)
| # #m2 | |
| # model2name = "UNNAMED" | |
| # if not os.getenv('MODEL2NAME'): print("MODEL2NAME env var non existent;using default:" + model2name) | |
| # else: | |
| # model2name = os.getenv('MODEL2NAME') | |
| # print("MODEL2NAME value:" + model2name) | |
| # #m3 | |
| # model3name = "UNNAMED" | |
| # if not os.getenv('MODEL3NAME'): print("MODEL3NAME env var non existent;using default:" + model3name) | |
| # else: | |
| # model3name = os.getenv('MODEL3NAME') | |
| # print("MODEL3NAME value:" + model3name) | |
| ####################################################### | |
def get_model_simplified_name_from_dirname(dirname):
    """Return *dirname* with the model naming noise removed.

    The substrings below are deleted wherever they occur, in this order
    ("_864x" must go before "_864"). The result is printed and returned.
    """
    simple_name = dirname
    for noise in ("model_", "_864x", "_864", "_new", "-864"):
        simple_name = simple_name.replace(noise, "")
    print(" result_simple_name:" + simple_name)
    return simple_name
def get_padded_checkpoint_no_from_filename(checkpoint_filename):
    """Return the checkpoint step of *checkpoint_filename* in thousands, zero-padded to 3 digits.

    The step is parsed from the ``ckpt-<digits>`` part of the name, divided
    by 1000 and rounded with ``round`` (exact ties go to the even value):
    ``"model.ckpt-45000"`` -> ``"045"``, ``"model.ckpt-1234567"`` -> ``"1235"``.
    A name without ``ckpt-<digits>`` yields ``"000"`` (instead of ``None``,
    which callers would later crash on when calling ``.zfill``).
    """
    match = re.search(r'ckpt-(\d+)', checkpoint_filename)
    if not match:
        print("WARNING:no checkpoint number in " + str(checkpoint_filename))
        return "000"
    # round() without ndigits returns an int, so str() gives "45", not "45.0",
    # and zfill can actually pad it.
    checkpoint_number = int(round(int(match.group(1)) / 1000))
    print(checkpoint_number)
    return str(checkpoint_number).zfill(3)
# Module state filled in by setup(): simplified model directory name and the
# checkpoint number string.
found_model, found_model_checkpoint = 'none', '0'
#########################################################
# SETUP
# NOTE(review): runway_files is not referenced in this file; presumably meant
# for a runway setup-options declaration -- confirm.
runway_files = runway.file(is_directory=True)
def setup(opts):
    """Build the Adaptive Style Transfer graph and restore its latest checkpoint.

    Args:
        opts: Runway setup options. ``opts['styleCheckpoint']`` is the root
            directory holding one sub-directory per model. When *opts* is
            ``None``, lacks that key, or names a missing path, the root falls
            back to ``/data/styleCheckpoint``.

    The model sub-directory comes from the ``MODELNAME`` environment
    variable, else the alphabetically first sub-directory of the root. Its
    ``checkpoint_long`` folder must contain a TensorFlow checkpoint.

    Side effects:
        Sets the module globals ``found_model`` (simplified model name) and
        ``found_model_checkpoint`` (checkpoint number string).

    Returns:
        An object whose ``m1`` attribute is a dict with ``sess``,
        ``input_photo`` and ``output_photo``, as used by stylize().

    Raises:
        FileNotFoundError: the root has no model sub-directory, or the
            checkpoint folder holds no checkpoint.
    """
    global found_model, found_model_checkpoint
    import types  # local import: only needed for the returned namespace

    sess = tf.Session()
    with tf.name_scope('placeholder'):
        input_photo = tf.placeholder(dtype=tf.float32,
                                     shape=[1, None, None, 3],
                                     name='photo')
    input_photo_features = encoder(image=input_photo,
                                   options={'gf_dim': 32},
                                   reuse=False)
    output_photo = decoder(features=input_photo_features,
                           options={'gf_dim': 32},
                           reuse=False)
    # Create the initializer only after the variables exist (previously it
    # was built on an empty graph); saver.restore() below overwrites them.
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    print("-------------====PATH---------------------->>>>--")
    path_default = '/data/styleCheckpoint'
    print("opts:")
    print(opts)
    print("----------------------------------------")
    if opts is None:
        print("ERROR:opts is None")
    try:
        path = opts['styleCheckpoint']
    except (TypeError, KeyError):
        # opts is None / not subscriptable, or has no styleCheckpoint entry.
        path = path_default
    if not os.path.exists(path):
        print("ERROR:Path does not exist:" + str(path))
        path = path_default
    print(path)
    print("----------------PATH=======---------------<<<<--")

    # Only scan the directory when MODELNAME does not already name the model.
    model_name = os.getenv('MODELNAME')
    if not model_name:
        # Sorted so the fallback does not depend on os.listdir() order.
        model_dirs = sorted(p for p in os.listdir(path)
                            if os.path.isdir(os.path.join(path, p)))
        if not model_dirs:
            raise FileNotFoundError("No model sub-directory found in " + str(path))
        model_name = model_dirs[0]
        dtprint("CONFIG::MODELNAME env var non existent;using default:" + model_name)

    checkpoint_dir = os.path.join(path, model_name, 'checkpoint_long')
    print("-----------------------------------------")
    print("modelname is : " + model_name)
    found_model = get_model_simplified_name_from_dirname(model_name)
    print("checkpoint_dir is : " + checkpoint_dir)
    print("-----------------------------------------")

    # get_checkpoint_state() returns None when the folder has no checkpoint.
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt is None or not ckpt.model_checkpoint_path:
        raise FileNotFoundError("No TensorFlow checkpoint found in " + checkpoint_dir)
    ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
    found_model_checkpoint = get_padded_checkpoint_no_from_filename(ckpt_name)
    saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))

    return types.SimpleNamespace(
        m1=dict(sess=sess, input_photo=input_photo, output_photo=output_photo))
def make_target_output_filename(mname, checkpoint, fn='', res1=0, abc=0, ext='.jpg', svrtype="s1", modelid='', suffix='', xtra_model_id='', verbose=False):
    """Compose the output filename for a stylized image.

    Layout before clean-up: ``<base>__<modelid>__<mtag><xtra_model_id><ext>``
    with ``mtag = <mname>__<suffix>__<res1>x<abc>__<svrtype>__<checkpoint>k``.

    * ``fn``: ``ext`` and the jpg/jpeg/png extensions (lower and upper case)
      are removed wherever they occur; ``"none"`` becomes empty; only the
      part after the last ``/`` is kept.
    * ``res1`` is zero-padded to 4 digits (``0`` means omitted), ``abc`` to
      2 digits, and the string ``checkpoint`` to 3 digits.
    * A fixed, order-sensitive list of substring substitutions (``cleanup``
      below) is then applied, followed by replace_values_from_csv().
    * ``verbose`` prints every input and the intermediate tag.
    """
    base = fn
    for known_ext in (ext, ".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"):
        base = base.replace(known_ext, "")

    res1_pad = str(res1).zfill(4)
    if res1_pad == "0000":
        res1_pad = ""
    abc_pad = str(abc).zfill(2)
    checkpoint = checkpoint.zfill(3)

    if base == "none":
        base = ""
    # Keep only the final path component (no-op when there is no '/').
    base = base.rsplit('/', 1)[-1]

    if verbose:
        print("-----------------------------")
        for label, value in (
            ("fn_base: ", base),
            ("mname: ", mname),
            ("suffix: ", suffix),
            ("res1: ", res1_pad),
            ("abc: ", abc_pad),
            ("ext: ", ext),
            ("svrtype: ", svrtype),
            ("modelid: ", modelid),
            ("xtra_model_id: ", xtra_model_id),
            ("checkpoint: ", checkpoint),
            ("fn: ", fn),
        ):
            print(label, value)

    mtag = "{}__{}__{}x{}__{}__{}k".format(mname, suffix, res1_pad, abc_pad, svrtype, checkpoint)
    mtag = mtag.replace("_0x" + abc_pad, "")
    if verbose:
        print(mtag)

    target_output = "{}__{}__{}{}{}".format(base, modelid, mtag, xtra_model_id, ext)
    # Applied strictly in this order; several entries depend on earlier ones.
    cleanup = (
        ("_" + abc_pad + "x" + abc_pad + "_", ""),
        ("_0x0_", ""),
        ("_0_", ""),
        ("_-", "_"),
        ("____", "__"),
        ("___", "__"),
        ("___", "__"),
        ("..", "."),
        ("model_", ""),
        ("_x" + abc_pad + "_", ""),
        ("gia-ds-", ""),
    )
    for old, new in cleanup:
        target_output = target_output.replace(old, new)
    return replace_values_from_csv(target_output)
def replace_values_from_csv(target_output, src_dest_file='replacer.csv'):
    """Apply the ``src,dst`` substitutions listed in *src_dest_file* to *target_output*.

    Each non-blank line is split on its FIRST comma into the substring to
    find and its replacement; substitutions run in file order. Lines with no
    comma or an empty ``src`` are skipped with a warning (an empty ``src``
    would otherwise insert ``dst`` between every character). A missing file
    means no substitutions. A relative path is resolved against the current
    working directory.

    Afterwards, newlines and carriage returns are removed and spaces become
    underscores; this happens even when the file does not exist.
    """
    if os.path.exists(src_dest_file):
        with open(src_dest_file, 'r', encoding='utf-8') as fh:
            for line in fh:
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue
                src, sep, dst = line.partition(',')
                if not sep or not src:
                    print("WARNING:skipping invalid replacer line: " + repr(line))
                    continue
                target_output = target_output.replace(src, dst)
    return target_output.replace("\n", "").replace("\r", "").replace(" ", "_")
def _make_meta_as_json(x1=0, c1=0, inp=None, result_dict=None):
    """Return the loaded model's metadata plus the output filename for a request.

    Args:
        x1: pass-1 resolution, forwarded to make_target_output_filename as ``res1``.
        c1: auto-contrast value, forwarded as ``abc``.
        inp: request mapping; its optional ``fn`` / ``ext`` keys default to
            ``'none'`` / ``'.jpg'``. ``None`` (the default) is accepted.
        result_dict: if given, the ``model``, ``checkpoint`` and ``filename``
            entries are added to it in place and it is returned.

    Returns:
        A dict with string values for ``model``, ``checkpoint`` and ``filename``
        (or *result_dict* updated with them).
    """
    request = inp or {}
    fn = request.get('fn', 'none')
    ext = request.get('ext', '.jpg')
    # str() guards against a non-string checkpoint (make_target_output_filename
    # calls .zfill on it).
    filename = make_target_output_filename(found_model, str(found_model_checkpoint), fn, x1, c1, ext, SRV_TYPE)
    meta = {
        "model": str(found_model),
        "checkpoint": str(found_model_checkpoint),
        "filename": str(filename),
    }
    if result_dict is None:
        return meta
    # Support adding the data directly to an existing result dict.
    result_dict.update(meta)
    return result_dict
# Input/output type specs for the metadata request (presumably meant for a
# runway command registration; not referenced in this file).
meta_inputs = dict(meta=text)
meta_outputs = dict(model=text, filename=text, checkpoint=text)
def get_geta(models, inp):
    """Return and print the loaded model's metadata (model, checkpoint, filename).

    Args:
        models: setup() result; unused, kept for the command signature.
        inp: request mapping (or ``None``); optional ``fn`` / ``ext`` keys
            shape the reported filename (defaults ``'none'`` / ``'.jpg'``).
    """
    request = inp or {}
    # Always hand _make_meta_as_json a mapping with the keys it reads, and
    # pass it by keyword: positionally it would land in the ``x1`` slot.
    meta_inp = {'fn': request.get('fn', 'none'), 'ext': request.get('ext', '.jpg')}
    json_return = _make_meta_as_json(inp=meta_inp)
    print(json_return)
    return json_return
# Input/output type specs for stylize().
# @STCGoal: add number or text inputs to specify the resolution of each of the three passes.
inputs = dict(
    contentImage=runway.image,
    x1=number(default=1024, min=24, max=18000),
    c1=number(default=0, min=-99, max=99),
    fn=text(default='none'),
    ext=text(default='.jpg'),
)
outputs = dict(
    stylizedImage=runway.image,
    totaltime=number,
    x1=number,
    c1=number,
    model1name=text,
    checkpoint=text,
    filename=text,
    model=text,
)
def _resize_smaller_side(img, target_size):
    """Return *img* scaled (bilinear) so its smaller side is *target_size* pixels.

    An image with a zero-length side is returned unchanged.
    """
    height, width = img.shape[:2]
    smaller_side = min(height, width)
    if smaller_side == 0:
        return img
    scale = float(target_size) / float(smaller_side)
    new_width = max(1, int(round(width * scale)))
    new_height = max(1, int(round(height * scale)))
    # cv2.resize takes the destination size as (width, height).
    return cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_LINEAR)


def stylize(models, inp):
    """Run one style-transfer pass on ``inp['contentImage']``.

    Args:
        models: setup() result; ``models.m1`` holds the session and the
            input/output tensors.
        inp: request mapping with ``contentImage``, ``x1`` (target length in
            pixels of the smaller side), ``c1`` (clip percentage passed to
            automatic_brightness_and_contrast; 0 skips it) and the ``fn`` /
            ``ext`` fields used for the output filename.

    Returns:
        dict with the stylized uint8 image, ``totaltime`` (seconds), ``x1``,
        ``c1``, ``model1name`` and the model/checkpoint/filename metadata.
    """
    start = time.time()
    model = models.m1
    x1 = int(inp['x1'])
    c1 = int(inp['c1'])

    # Resize the raw image first, then map [0, 255] -> [-1, 1] for the model.
    # The old order (normalise, then scipy.misc.imresize) was broken:
    # imresize rescales float input back to uint8 [0, 255], and it no longer
    # exists in recent SciPy, where the bare except silently skipped resizing.
    img = np.array(inp['contentImage'])
    img = _resize_smaller_side(img, x1)
    img = img / 127.5 - 1.
    img = np.expand_dims(img, axis=0)

    dtprint("INFO:Pass1 inference starting")
    img = model['sess'].run(model['output_photo'], feed_dict={model['input_photo']: img})

    # Back to uint8; clip so out-of-range values saturate rather than
    # overflowing in the uint8 cast.
    img = np.clip((img + 1.) * 127.5, 0, 255).astype('uint8')[0]

    if c1 != 0:
        print('Auto Brightening images...' + str(c1))
        img, _alpha, _beta = automatic_brightness_and_contrast(img, c1)

    totaltime = time.time() - start
    print("The time of the run:", totaltime)

    # "UNNAMED" (the MODEL1NAME default) falls back to the model found in setup().
    display_name = found_model if model1name == "UNNAMED" else model1name

    include_meta_directly_in_result = True
    if include_meta_directly_in_result:
        result_dict = dict(stylizedImage=img, totaltime=totaltime, x1=x1, model1name=display_name, c1=c1)
        result_dict = _make_meta_as_json(x1, c1, inp, result_dict)
    else:
        meta_data = _make_meta_as_json(x1, c1, inp)
        result_dict = dict(stylizedImage=img, totaltime=totaltime, x1=x1, model1name=display_name, c1=c1, meta=meta_data)
    return result_dict
def dtprint(msg):
    """Print *msg* prefixed with the getdttag() time tag and ``::``."""
    line = getdttag() + "::" + msg
    print(line)
def getdttag():
    """Return the current local time formatted as ``HH:MM:SS`` (used as a log prefix)."""
    return format(datetime.now(), "%H:%M:%S")
# Automatic brightness and contrast optimization with optional histogram clipping
def automatic_brightness_and_contrast(image, clip_hist_percent=25):
    """Stretch the image's gray-level range after clipping the histogram tails.

    Args:
        image: 3-channel uint8 image. It is converted to gray with
            ``cv2.COLOR_BGR2GRAY``, so BGR order is assumed.
            NOTE(review): stylize() presumably passes RGB (from a PIL image),
            which swaps the red/blue weights of the gray estimate -- confirm.
        clip_hist_percent: total percentage of pixels to clip, split evenly
            between the dark and bright ends; negative values clip nothing.

    Returns:
        ``(result, alpha, beta)`` with ``result = cv2.convertScaleAbs(image,
        alpha, beta)``. If the clipped gray range is empty (e.g. a uniform or
        empty image), *image* is returned unchanged with ``alpha=1.0`` and
        ``beta=0.0``. Previously such images produced a negative alpha, which
        inverted or blacked them out, or an IndexError for all-black input.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    # Grayscale histogram and its cumulative distribution (float64 avoids
    # precision loss on large images).
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
    hist_size = len(hist)
    accumulator = np.cumsum(hist.ravel(), dtype=np.float64)

    # Number of pixels to clip at each end.
    maximum = accumulator[-1]
    clip = clip_hist_percent * (maximum / 100.0) / 2.0

    # Left cut: first gray level whose cumulative count reaches the clip.
    minimum_gray = 0
    while minimum_gray < hist_size - 1 and accumulator[minimum_gray] < clip:
        minimum_gray += 1

    # Right cut: walk down while the top tail is still inside the clip; the
    # lower bound keeps the index from going negative (and wrapping).
    maximum_gray = hist_size - 1
    while maximum_gray > 0 and accumulator[maximum_gray] >= (maximum - clip):
        maximum_gray -= 1

    if maximum_gray <= minimum_gray:
        return (image, 1.0, 0.0)

    # Map [minimum_gray, maximum_gray] onto [0, 255].
    alpha = 255 / (maximum_gray - minimum_gray)
    beta = -minimum_gray * alpha
    auto_result = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
    return (auto_result, alpha, beta)
if __name__ == '__main__':
    # Respect an RW_PORT supplied by the environment (the module top already
    # defaults it to 7860 when unset); previously it was overwritten here.
    if not os.getenv('RW_PORT'):
        os.environ["RW_PORT"] = "7860"
    print("Launched...")
    runway.run()