|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import argparse
|
| import os
|
| import tqdm
|
| from statistics import fmean
|
| from eval.syncnet import SyncNetEval
|
| from eval.syncnet_detect import SyncNetDetector
|
| from latentsync.utils.util import red_text
|
| import torch
|
|
|
|
|
def syncnet_eval(syncnet, syncnet_detector, video_path, temp_dir, detect_results_dir="detect_results"):
    """Detect faces in one video and score its audio-visual sync with SyncNet.

    ``syncnet_detector`` is run on ``video_path`` first. Every file found
    afterwards in ``<detect_results_dir>/crop`` is scored with
    ``syncnet.evaluate``, and the per-crop results are averaged.

    Args:
        syncnet: Object exposing ``evaluate(video_path=..., temp_dir=...)``
            that returns an ``(av_offset, _, confidence)`` triple.
        syncnet_detector: Callable invoked as
            ``syncnet_detector(video_path=..., min_track=50)``.
        video_path: Path of the input video.
        temp_dir: Scratch directory forwarded to ``syncnet.evaluate``.
        detect_results_dir: Directory whose ``crop`` subdirectory is read after
            detection. It should match the directory the detector was
            configured with.

    Returns:
        ``(av_offset, conf)``: the mean AV offset converted with ``int`` (which
        truncates toward zero), and the mean confidence over all crops.

    Raises:
        RuntimeError: If the crop directory is empty (no face detected).
    """
    syncnet_detector(video_path=video_path, min_track=50)

    crop_dir = os.path.join(detect_results_dir, "crop")
    # os.listdir order is arbitrary; sorting makes the processing order (and
    # the floating-point summation order) reproducible across runs.
    crop_videos = sorted(os.listdir(crop_dir))
    if not crop_videos:
        raise RuntimeError(red_text(f"Face not detected in {video_path}"))

    av_offset_list = []
    conf_list = []
    for video in crop_videos:
        av_offset, _, conf = syncnet.evaluate(video_path=os.path.join(crop_dir, video), temp_dir=temp_dir)
        av_offset_list.append(av_offset)
        conf_list.append(conf)

    # int() truncates toward zero rather than rounding (a mean of 2.5 reports 2).
    av_offset = int(fmean(av_offset_list))
    conf = fmean(conf_list)

    print(f"Input video: {video_path}\nSyncNet confidence: {conf:.2f}\nAV offset: {av_offset}")

    return av_offset, conf
|
|
|
|
|
def main():
    """Command-line entry point: evaluate SyncNet sync confidence for videos.

    With ``--video_path``, a single video is evaluated and any error
    propagates. Otherwise every ``.mp4`` file directly inside ``--videos_dir``
    is evaluated in sorted order. Per-video failures are reported and
    skipped, and the mean confidence over the successfully evaluated videos
    is printed (or a notice if none succeeded).
    """
    parser = argparse.ArgumentParser(description="SyncNet")
    parser.add_argument(
        "--initial_model",
        type=str,
        default="checkpoints/auxiliary/syncnet_v2.model",
        help="path of the SyncNet checkpoint to load",
    )
    parser.add_argument(
        "--video_path",
        type=str,
        default=None,
        help="evaluate only this video (takes precedence over --videos_dir)",
    )
    parser.add_argument(
        "--videos_dir",
        type=str,
        default="/root/processed",
        help="directory whose .mp4 files are evaluated when --video_path is not given",
    )
    parser.add_argument("--temp_dir", type=str, default="temp", help="scratch directory used during evaluation")
    parser.add_argument(
        "--detect_results_dir",
        type=str,
        default="detect_results",
        help="directory shared by the face detector and the evaluator for crop results",
    )

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"

    syncnet = SyncNetEval(device=device)
    syncnet.loadParameters(args.initial_model)

    # syncnet_eval reads crops from <detect_results_dir>/crop after running the
    # detector, so the detector is presumably expected to write them there.
    # Pass the same directory to both so they cannot drift apart.
    syncnet_detector = SyncNetDetector(device=device, detect_results_dir=args.detect_results_dir)

    if args.video_path is not None:
        syncnet_eval(
            syncnet,
            syncnet_detector,
            args.video_path,
            args.temp_dir,
            detect_results_dir=args.detect_results_dir,
        )
        return

    sync_conf_list = []
    video_names = sorted(f for f in os.listdir(args.videos_dir) if f.endswith(".mp4"))
    for video_name in tqdm.tqdm(video_names):
        video_path = os.path.join(args.videos_dir, video_name)
        try:
            _, conf = syncnet_eval(
                syncnet,
                syncnet_detector,
                video_path,
                args.temp_dir,
                detect_results_dir=args.detect_results_dir,
            )
        except Exception as e:
            # Best-effort batch run: report which video failed and keep going.
            print(f"{video_path}: {e}")
            continue
        sync_conf_list.append(conf)

    # fmean raises StatisticsError on an empty list, which happens when no
    # .mp4 files exist or every video failed.
    if not sync_conf_list:
        print("No videos were evaluated successfully; the average sync confidence is unavailable.")
        return

    print(f"The average sync confidence is {fmean(sync_conf_list):.02f}")


if __name__ == "__main__":
    main()
|
|
|