import datetime
import os

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
# from tensorboardX import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch.nn.functional as F

import joint_transforms_edge
# from config import msd_training_root
from config import msd_training_root
from config import msd_testing_root
from config import backbone_path
from dataset_edge import ImageFolder
from misc import AvgMeter, check_mkdir
# from model.mirrornet_edge_cat_img_nonlocal import TAYLOR5_TWOC
from model.vcn import VCN

import loss as L

# Ask cuDNN to benchmark convolution algorithms and reuse the fastest one.
cudnn.benchmark = True

# Expose only the GPU with index 1 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

# Checkpoints and logs for this run are written under <ckpt_path>/<exp_name>/.
ckpt_path, exp_name = 'ckpt', 'vcnet'

# Hyper-parameters and run options, read by main() and train().
args = dict(
    epoch_num=160,
    train_batch_size=10,
    last_epoch=0,
    lr=1e-3,
    lr_decay=0.9,
    weight_decay=5e-4,
    momentum=0.9,
    snapshot='',
    scale=384,
    save_point=[
        10, 20, 25, 30, 35, 40, 45, 50, 60,
        70, 80, 100, 110, 120, 130, 140, 150, 160,
    ],
    add_graph=False,
    poly_train=True,
    optimizer='SGD',
)

# Output locations: everything for this run lives under <ckpt_path>/<exp_name>/.
_exp_dir = os.path.join(ckpt_path, exp_name)
vis_path = os.path.join(_exp_dir, 'log')
for _out_dir in (ckpt_path, _exp_dir, vis_path):
    check_mkdir(_out_dir)
log_path = os.path.join(_exp_dir, '123.txt')
val_log_path = os.path.join(_exp_dir, 'val123.txt')

# Joint transform: random rotation, then resize to a square of side args['scale'].
joint_transform = joint_transforms_edge.Compose([
    joint_transforms_edge.RandomRotate(),
    joint_transforms_edge.Resize((args['scale'],) * 2),
])

# Image-only transform: to tensor, then per-channel normalisation with the
# commonly used ImageNet mean/std values.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Masks and edge maps are only converted to tensors.
target_transform = transforms.ToTensor()
edge_transform = transforms.ToTensor()


# Datasets and loaders.
# NOTE(review): the validation set below reuses joint_transform (which includes
# RandomRotate) and is shuffled, exactly like the training set -- confirm that
# randomly augmented validation is intended.
_sample_transforms = (joint_transform, img_transform, target_transform, edge_transform)

train_set = ImageFolder(msd_training_root, *_sample_transforms)
print("Train set: {}".format(len(train_set)))
train_loader = DataLoader(
    train_set,
    batch_size=args['train_batch_size'],
    num_workers=0,
    shuffle=True,
)

val_set = ImageFolder(msd_testing_root, *_sample_transforms)
val_loader = DataLoader(
    val_set,
    batch_size=args['train_batch_size'],
    num_workers=0,
    shuffle=True,
)

# Despite its name this is the total number of training *iterations*
# (epochs x batches per epoch); train() uses it for the poly LR schedule.
total_epoch = len(train_loader) * args['epoch_num']



def bce2d_new_weights(input, target):
    """Build per-element class-balancing weights for a binary target map.

    Elements whose target is exactly 1 receive ``num_neg / num_total`` and
    elements whose target is exactly 0 receive ``1.1 * num_pos / num_total``,
    where the counts are taken over the whole ``target`` tensor. Elements with
    any other target value receive weight 0.

    Args:
        input: Prediction tensor; only its size is checked against ``target``.
        target: Ground-truth tensor with the same size as ``input``.

    Returns:
        A float tensor with the same shape as ``target`` holding the weights.

    Raises:
        AssertionError: if ``input`` and ``target`` differ in size.
    """
    assert input.size() == target.size()
    pos = torch.eq(target, 1).float()
    neg = torch.eq(target, 0).float()

    num_pos = torch.sum(pos)
    num_neg = torch.sum(neg)
    # If no element is exactly 0 or 1 the denominator would be 0 and every
    # weight would become 0/0 = NaN. Clamping to 1 keeps the weights finite
    # (all zeros in that case) and leaves every other case unchanged.
    num_total = torch.clamp(num_pos + num_neg, min=1)

    alpha = num_neg / num_total
    beta = 1.1 * num_pos / num_total
    # target == 1 -> weight alpha (fraction of negatives)
    # target == 0 -> weight beta  (1.1 x fraction of positives)
    weights = alpha * pos + beta * neg

    return weights

def main():
    """Build the VCN model and its optimizer, optionally resume, then train.

    Reads the module-level ``args``, ``exp_name``, ``ckpt_path``,
    ``backbone_path``, ``train_loader`` and ``log_path``. When
    ``args['snapshot']`` is non-empty, the weights are loaded from
    ``<ckpt_path>/<exp_name>/<snapshot>.pth`` and the module-level
    ``total_epoch`` (total iteration count used by ``train`` for the poly
    learning-rate schedule) is shortened to the remaining epochs.
    """
    # train() reads the module-level total_epoch, so the resume branch below
    # has to rebind the global instead of creating a function-local variable.
    global total_epoch

    print(args)
    print(exp_name)

    net = VCN(backbone_path)
    net = nn.DataParallel(net)
    net.train()
    net.cuda()

    # Two parameter groups, shared by both optimizers: parameters whose name
    # ends in 'bias' use twice the base lr and the optimizer's default weight
    # decay; all other parameters use the base lr and args['weight_decay'].
    named_params = list(net.named_parameters())
    param_groups = [
        {'params': [param for name, param in named_params if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in named_params if name[-4:] != 'bias'],
         'lr': 1 * args['lr'], 'weight_decay': args['weight_decay']},
    ]

    if args['optimizer'] == 'Adam':
        print("Adam")
        optimizer = optim.Adam(param_groups)
    else:
        print("SGD")
        optimizer = optim.SGD(param_groups, momentum=args['momentum'])

    if args['snapshot']:
        print('Training Resumes From \'%s\'' % args['snapshot'])
        snapshot_file = os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')
        # train() saves net.module.state_dict(), whose keys carry no 'module.'
        # prefix, so the snapshot must be loaded into the wrapped module rather
        # than into the DataParallel wrapper.
        net.module.load_state_dict(torch.load(snapshot_file))
        total_epoch = (args['epoch_num'] - int(args['snapshot'])) * len(train_loader)
        print(total_epoch)

    # Start a fresh training log with the run configuration as its header.
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    train(net, optimizer)




# Names of the tracked losses; 'total' is the sum of the other six.
_LOSS_NAMES = ('total', 'l4', 'l3', 'l2', 'l1', 'edge', 'final')


def _append_line(path, line):
    """Append ``line`` plus a newline to the text file at ``path`` and close it."""
    with open(path, 'a') as log_file:
        log_file.write(line + '\n')


def _new_meters():
    """Return one fresh AvgMeter per tracked loss, keyed by loss name."""
    return {name: AvgMeter() for name in _LOSS_NAMES}


def _compute_losses(net, inputs, labels, edges):
    """Run ``net`` on ``inputs`` and return every loss term keyed by name.

    The four side outputs and the final output are scored against ``labels``
    with ``L.lovasz_hinge``. The edge output is scored against ``edges`` with
    a BCE-with-logits loss weighted by ``bce2d_new_weights`` and scaled by
    100. ``'total'`` is the sum of the six terms.
    """
    predict_4, predict_3, predict_2, predict_1, predict_edge, final_predict = net(inputs)

    losses = {
        'l4': L.lovasz_hinge(predict_4, labels),
        'l3': L.lovasz_hinge(predict_3, labels),
        'l2': L.lovasz_hinge(predict_2, labels),
        'l1': L.lovasz_hinge(predict_1, labels),
    }

    # Functional form of nn.BCEWithLogitsLoss(weight=...): same result without
    # constructing a new module on every iteration.
    bce_weight = bce2d_new_weights(predict_edge, edges)
    losses['edge'] = 100 * F.binary_cross_entropy_with_logits(
        predict_edge, edges, weight=bce_weight)

    losses['final'] = L.lovasz_hinge(final_predict, labels)

    losses['total'] = (losses['l4'] + losses['l3'] + losses['l2'] + losses['l1']
                       + losses['edge'] + losses['final'])
    return losses


def _record(meters, losses, batch_size):
    """Feed each loss into its meter exactly once, as a plain Python float."""
    for name in _LOSS_NAMES:
        meters[name].update(losses[name].item(), batch_size)


def train(net, optimizer):
    """Train ``net`` for ``args['epoch_num']`` epochs, validating after each.

    Uses the module-level ``args``, ``train_loader``, ``val_loader``,
    ``total_epoch``, ``log_path``, ``val_log_path``, ``ckpt_path`` and
    ``exp_name``. One line of running loss averages is appended to
    ``log_path`` per training iteration and to ``val_log_path`` per
    validation iteration. ``net.module.state_dict()`` is saved to
    ``<ckpt_path>/<exp_name>/<epoch>.pth`` every 10th epoch and at the final
    epoch, after which the function returns.

    Args:
        net: Model wrapped so that ``net.module`` is the underlying network
            (main() wraps it in ``nn.DataParallel``).
        optimizer: Optimizer whose first param group holds the biases (2x lr)
            and whose second holds the remaining parameters (1x lr).
    """
    curr_iter = 1
    # The learning rate is logged on every iteration, so it must be defined
    # even when the poly schedule is disabled.
    base_lr = args['lr']

    for epoch in range(args['last_epoch'] + 1, args['last_epoch'] + 1 + args['epoch_num']):
        meters = _new_meters()

        train_iterator = tqdm(train_loader, total=len(train_loader))
        for inputs, labels, edges in train_iterator:
            if args['poly_train']:
                # Clamp progress to 1: past total_epoch iterations the base of
                # the power would go negative and a fractional exponent would
                # produce a complex number instead of a learning rate.
                progress = min(float(curr_iter) / float(total_epoch), 1.0)
                base_lr = args['lr'] * (1 - progress) ** args['lr_decay']
                optimizer.param_groups[0]['lr'] = 2 * base_lr
                optimizer.param_groups[1]['lr'] = 1 * base_lr

            batch_size = inputs.size(0)
            inputs = inputs.cuda()
            labels = labels.cuda()
            # as_tensor avoids the copy-construct warning torch.tensor() emits
            # when handed an existing tensor.
            edges = torch.as_tensor(edges).cuda()

            optimizer.zero_grad()
            losses = _compute_losses(net, inputs, labels, edges)
            losses['total'].backward()
            optimizer.step()

            _record(meters, losses, batch_size)

            log = '[%3d], [%6d], [%.6f], [%.5f], [L4: %.5f], [L3: %.5f], [L2: %.5f], [L1: %.5f], [edge: %.5f], [final_l: %.5f]' % \
                  (epoch, curr_iter, base_lr, meters['total'].avg, meters['l4'].avg, meters['l3'].avg,
                   meters['l2'].avg, meters['l1'].avg, meters['edge'].avg, meters['final'].avg)
            train_iterator.set_description(log)
            _append_line(log_path, log)

            curr_iter += 1

        print('start validation')
        # NOTE(review): net is never switched to eval() here, so any
        # BatchNorm/Dropout layers VCN may contain behave as in training.
        # Confirm VCN returns the same six outputs in eval mode before adding
        # net.eval()/net.train() around this block.
        with torch.no_grad():
            val_iterator = tqdm(val_loader, total=len(val_loader))
            meters = _new_meters()
            for inputs, labels, edges in val_iterator:
                batch_size = inputs.size(0)
                inputs = inputs.cuda()
                labels = labels.cuda()
                edges = torch.as_tensor(edges).cuda()

                losses = _compute_losses(net, inputs, labels, edges)
                _record(meters, losses, batch_size)

                log = '[%3d], [%.5f], [L4: %.5f], [L3: %.5f], [L2: %.5f], [L1: %.5f], [edge: %.5f], [final_l: %.5f]' % \
                      (epoch, meters['total'].avg, meters['l4'].avg, meters['l3'].avg, meters['l2'].avg,
                       meters['l1'].avg, meters['edge'].avg, meters['final'].avg)
                val_iterator.set_description(log)
                _append_line(val_log_path, log)

        is_last_epoch = epoch >= args['epoch_num']
        if epoch % 10 == 0 or is_last_epoch:
            # Save from CPU; write the file once even when the last epoch is
            # also a multiple of 10.
            net.cpu()
            torch.save(net.module.state_dict(), os.path.join(ckpt_path, exp_name, '%d.pth' % epoch))
            if not is_last_epoch:
                net.cuda()

        if is_last_epoch:
            print("Optimization Have Done!")
            return


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()
