summaryrefslogtreecommitdiff
path: root/tools/3D-Reconstruction/MotionEST/MotionEST.py
diff options
context:
space:
mode:
authorDan Zhu <zxdan@google.com>2019-07-17 15:20:27 -0700
committerDan Zhu <zxdan@google.com>2019-07-23 09:49:59 -0700
commit3244f4738d16cfef982ebf94e70dadd95d4c47b8 (patch)
treed3f2f89a1e3c86994fddb450f0879eae26d0777d /tools/3D-Reconstruction/MotionEST/MotionEST.py
parent6b5db3f9da3f8b21cf582ba0d8e375a70427d0e1 (diff)
downloadlibvpx-3244f4738d16cfef982ebf94e70dadd95d4c47b8.tar
libvpx-3244f4738d16cfef982ebf94e70dadd95d4c47b8.tar.gz
libvpx-3244f4738d16cfef982ebf94e70dadd95d4c47b8.tar.bz2
libvpx-3244f4738d16cfef982ebf94e70dadd95d4c47b8.zip
Based Class of Motion Field Estimators
Change-Id: Id01ce15273c0cab0cd61d064099d200708360265
Diffstat (limited to 'tools/3D-Reconstruction/MotionEST/MotionEST.py')
-rw-r--r--tools/3D-Reconstruction/MotionEST/MotionEST.py113
1 files changed, 113 insertions, 0 deletions
diff --git a/tools/3D-Reconstruction/MotionEST/MotionEST.py b/tools/3D-Reconstruction/MotionEST/MotionEST.py
new file mode 100644
index 000000000..0e04bdd40
--- /dev/null
+++ b/tools/3D-Reconstruction/MotionEST/MotionEST.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python
+# coding: utf-8
+import numpy as np
+import numpy.linalg as LA
+import matplotlib.pyplot as plt
+from Util import drawMF, MSE
+"""The Base Class of Estimators"""
+
+
+class MotionEST(object):
+ """
+ constructor:
+ cur_f: current frame
+ ref_f: reference frame
+ blk_sz: block size
+ """
+
+ def __init__(self, cur_f, ref_f, blk_sz):
+ self.cur_f = cur_f
+ self.ref_f = ref_f
+ self.blk_sz = blk_sz
+ #convert RGB to YUV
+ self.cur_yuv = np.array(self.cur_f.convert('YCbCr'))
+ self.ref_yuv = np.array(self.ref_f.convert('YCbCr'))
+ #frame size
+ self.width = self.cur_f.size[0]
+ self.height = self.cur_f.size[1]
+ #motion field size
+ self.num_row = self.height // self.blk_sz
+ self.num_col = self.width // self.blk_sz
+ #initialize motion field
+ self.mf = np.zeros((self.num_row, self.num_col, 2))
+
+ """
+ estimation function
+ Override by child classes
+ """
+
+ def motion_field_estimation(self):
+ pass
+
+ """
+ distortion of a block:
+ cur_r: current row
+ cur_c: current column
+ mv: motion vector
+ metric: distortion metric
+ """
+
+ def block_dist(self, cur_r, cur_c, mv, metric=MSE):
+ cur_x = cur_c * self.blk_sz
+ cur_y = cur_r * self.blk_sz
+ h = min(self.blk_sz, self.height - cur_y)
+ w = min(self.blk_sz, self.width - cur_x)
+ cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
+ ref_x = cur_x + mv[1]
+ ref_y = cur_y + mv[0]
+ if 0 <= ref_x < self.width and 0 <= ref_y < self.height:
+ ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
+ else:
+ ref_blk = np.zeros((h, w, 3))
+ return self.metric(cur_blk, ref_blk)
+
+ """
+ distortion of motion field
+ """
+
+ def distortion(self, metric=MSE):
+ loss = 0
+ for i in xrange(self.num_row):
+ for j in xrange(self.num_col):
+ loss += self.dist(i, j, self.mf[i, j], metric)
+ return loss / self.num_row / self.num_col
+
+ """
+ evaluation
+ compare the difference with ground truth
+ """
+
+ def motion_field_evaluation(self, ground_truth):
+ loss = 0
+ count = 0
+ gt = ground_truth.mf
+ mask = ground_truth.mask
+ for i in xrange(self.num_row):
+ for j in xrange(self.num_col):
+ if not mask is None and mask[i][j]:
+ continue
+ loss += LA.norm(gt[i, j] - self.mf[i, j])
+ count += 1
+ return loss / count
+
+ """
+ render the motion field
+ """
+
+ def show(self, ground_truth=None):
+ cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
+ if ground_truth is None:
+ n_row = 1
+ else:
+ gt_mf = drawMF(self.cur_f, self.blk_sz, ground_truth)
+ n_row = 2
+ plt.figure(figsize=(n_row * 10, 10))
+ plt.subplot(1, n_row, 1)
+ plt.imshow(cur_mf)
+ plt.title('Estimated Motion Field')
+ if not ground_truth is None:
+ plt.subplot(1, n_row, 2)
+ plt.imshow(gt_mf)
+ plt.title('Ground Truth')
+ plt.tight_layout()
+ plt.show()