"""Classes for creating video streams.
"""
import math
import cv2
import numpy as np
from abc import ABC, abstractmethod
from pathlib import Path
from pubsub import pub
import logging
# Module-level logger. It uses the shared 'lavaflow' logger name (rather than
# __name__), so all stream messages go to the package-wide logger.
logger = logging.getLogger('lavaflow')
# -----------------------------------------------------------------------------
# Streams
class AbstractStream(ABC):
    """Base interface for streams: an iterator that also knows its length.

    Concrete subclasses must provide ``__len__``, ``__iter__`` and
    ``__next__``; the abstract versions below only document the contract.
    """

    @abstractmethod
    def __len__(self):
        """Report how many items the stream can produce in total.

        Returns:
            length (int): number of items
        """

    @abstractmethod
    def __iter__(self):
        """Provide the iterator over the stream's items.

        Returns:
            self (AbstractStream): iterator
        """

    @abstractmethod
    def __next__(self):
        """Produce the following item of the stream.

        Returns:
            item: next item of the stream (an image array for video streams)
        """
# Video streams
class VideoStream(AbstractStream):
    """Video stream frame iterator.

    Wraps a ``cv2.VideoCapture`` and yields ``(img, frame_index)`` tuples,
    starting at ``frame_seek``, advancing ``step_size`` frames per step,
    stopping after ``max_steps`` frames (or at the end of the video) and
    optionally resizing every frame to ``res``.
    """

    def __init__(self, file, res=(-1, -1), frame_seek=0, step_size=1, max_steps=-1):
        """
        Args:
            file (str): file name
            res (tuple): output frame size written as (w, h); if either
                component is <= 0 the source size is kept
            frame_seek (int): frame to seek, default: 0 (start of video)
            step_size (int): number of frames to advance each step (>= 1)
            max_steps (int): total number of frames to capture,
                default: -1 (end of video)

        Raises:
            FileNotFoundError: if ``file`` does not exist
            ValueError: if the video cannot be opened or the seek parameters
                are out of range (see ``seek``)
        """
        logger.debug('init: file=%s', file)
        self.file = Path(file)
        if not self.file.exists():
            raise FileNotFoundError(f"Invalid video filepath '{self.file}'")
        self.video = cv2.VideoCapture(str(self.file))
        if not self.video.isOpened():
            raise ValueError(f"Unable to open video '{self.file}'")
        self.num_frames = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))
        # A zero dimension is not a valid resize target, so only strictly
        # positive sizes enable resizing.
        if res[0] > 0 and res[1] > 0:
            self.resize = True
            self.res = tuple(res)
        else:
            self.resize = False
            self.res = self.get_source_size()
        try:
            self.seek(frame_seek, step_size, max_steps)
        except ValueError:
            # Do not leak the capture handle when the arguments are rejected.
            self.video.release()
            raise

    def seek(self, frame_seek=0, step_size=1, max_steps=-1):
        """Reset video frame iterator back to ``frame_seek``.

        The capture is reopened if a previous iteration ran to completion and
        released it, so the stream can be iterated again after seeking.

        Args:
            frame_seek (int): frame to seek, default: 0 (start of video)
            step_size (int): number of frames to advance each step (>= 1)
            max_steps (int): total number of frames to capture,
                default: -1 (end of video)

        Raises:
            ValueError: if ``frame_seek`` is outside the video, ``step_size``
                is less than 1, ``max_steps * step_size`` exceeds the frames
                remaining after ``frame_seek``, or the video cannot be reopened
        """
        if frame_seek < 0 or frame_seek > self.num_frames - 1:
            raise ValueError(
                f'frame index to seek ({frame_seek}) out of range '
                f'[0, {self.num_frames - 1}] (num frames = {self.num_frames})')
        if step_size < 1:
            raise ValueError(f'step size ({step_size}) must be >= 1')
        max_frames = max_steps * step_size
        rem_frames = self.num_frames - frame_seek
        logger.debug('seek: num_frames=%d frame_seek=%d max_frames=%d rem_frames=%d',
                     self.num_frames, frame_seek, max_frames, rem_frames)
        if max_frames > rem_frames:
            raise ValueError(f'max steps * step size ({max_frames}) > num frames - frame index to seek ({rem_frames})')
        if not self.video.isOpened() and not self.video.open(str(self.file)):
            raise ValueError(f"Unable to reopen video '{self.file}'")
        self.frame_seek = frame_seek
        self.frame = frame_seek
        self.step_size = step_size
        self.max_steps = max_steps
        self.step_counter = 0
        self.video.set(cv2.CAP_PROP_POS_FRAMES, frame_seek)
        # Zero-padding width of the progress counter: the number of digits in
        # the total number of steps (always >= 1, even for an empty stream).
        self.pub = len(str(len(self)))

    def get_source_size(self):
        """Get input video size.

        Returns:
            res (tuple): video size ``(w, h)``
        """
        return (int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)))

    def get_output_size(self):
        """Get output frame size.

        Returns:
            res (tuple): frame size ``(w, h)``
        """
        return self.res

    def get_position(self):
        """Get the index of the next frame that will be returned.

        Returns:
            position (int): frame index of the next step
        """
        return self.frame

    def __len__(self):
        """Get total number of frames that will be retrieved.

        Returns:
            num_frames (int): either max_steps or the number of frames
            remaining from ``frame_seek`` at the current step size
        """
        if self.max_steps > -1:
            return self.max_steps
        # Frames frame_seek, frame_seek + step_size, ... up to num_frames - 1,
        # i.e. ceil(remaining / step_size) frames.
        rem_frames = self.num_frames - self.frame_seek
        return (rem_frames + self.step_size - 1) // self.step_size

    def __iter__(self):
        """Return iterator.

        Returns:
            self (VideoStream): video frame iterator
        """
        return self

    def __next__(self):
        """Get next frame.

        Iteration stops (and the capture is released) after ``max_steps``
        frames, or earlier if the video ends, is removed, or becomes
        unavailable.

        Returns:
            img (np.ndarray): frame
            frame (int): frame index
        """
        # TODO: release the video capture if a consumer abandons iteration
        # early or a downstream exception is raised.
        if self.video is None or (self.max_steps > -1 and self.step_counter >= self.max_steps):
            self._stop()
        # The first step returns the frame at frame_seek; every later step
        # advances step_size frames, grabbing (and discarding) skipped ones.
        num_grabs = 1 if self.step_counter == 0 else self.step_size
        for _ in range(num_grabs):
            if not self.video.grab():
                self._stop()
        ret, img = self.video.retrieve()
        if not ret or img is None:
            self._stop()
        if self.resize:
            # Pass interpolation by keyword: positionally the 5th argument of
            # cv2.resize is ``fy``, not ``interpolation``.
            img = cv2.resize(img, self.res, interpolation=cv2.INTER_AREA)
        frame = self.frame
        self.step_counter += 1
        self.frame += self.step_size
        logger.debug('next: index=%d max_steps=%d frame=%d img.shape=%s',
                     self.step_counter, self.max_steps, frame, img.shape)
        pub.sendMessage("log", message=f"{self.step_counter:0{self.pub}d} / {len(self)}")
        return img, frame

    def _stop(self):
        """Release the video capture and end iteration."""
        if self.video is not None:
            self.video.release()
        raise StopIteration