vidformer

A Python library for creating and viewing videos with vidformer.

  1"""A Python library for creating and viewing videos with vidformer."""
  2
  3__version__ = "0.5.2"
  4
  5import subprocess
  6from fractions import Fraction
  7import random
  8import time
  9import json
 10import socket
 11import os
 12import sys
 13import multiprocessing
 14import uuid
 15import threading
 16import gzip
 17import base64
 18
 19import requests
 20import msgpack
 21import numpy as np
 22
 23_in_notebook = False
 24try:
 25    from IPython import get_ipython
 26
 27    if "IPKernelApp" in get_ipython().config:
 28        _in_notebook = True
 29except:
 30    pass
 31
 32
 33def _check_hls_link_exists(url, max_attempts=150, delay=0.1):
 34    for attempt in range(max_attempts):
 35        try:
 36            response = requests.get(url)
 37            if response.status_code == 200:
 38                return response.text.strip()
 39            else:
 40                time.sleep(delay)
 41        except requests.exceptions.RequestException as e:
 42            time.sleep(delay)
 43    return None
 44
 45
 46class Spec:
 47    def __init__(self, domain: list[Fraction], render, fmt: dict):
 48        self._domain = domain
 49        self._render = render
 50        self._fmt = fmt
 51
 52    def __repr__(self):
 53        lines = []
 54        for i, t in enumerate(self._domain):
 55            frame_expr = self._render(t, i)
 56            lines.append(
 57                f"{t.numerator}/{t.denominator} => {frame_expr}",
 58            )
 59        return "\n".join(lines)
 60
 61    def _sources(self):
 62        s = set()
 63        for i, t in enumerate(self._domain):
 64            frame_expr = self._render(t, i)
 65            s = s.union(frame_expr._sources())
 66        return s
 67
 68    def _to_json_spec(self):
 69        frames = []
 70        s = set()
 71        f = {}
 72        for i, t in enumerate(self._domain):
 73            frame_expr = self._render(t, i)
 74            s = s.union(frame_expr._sources())
 75            f = {**f, **frame_expr._filters()}
 76            frame = [[t.numerator, t.denominator], frame_expr._to_json_spec()]
 77            frames.append(frame)
 78        return {"frames": frames}, s, f
 79
 80    def play(self, server, method="html", verbose=False):
 81        """Play the video live in the notebook."""
 82
 83        spec, sources, filters = self._to_json_spec()
 84        spec_json_bytes = json.dumps(spec).encode("utf-8")
 85        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
 86        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
 87
 88        sources = [
 89            {
 90                "name": s._name,
 91                "path": s._path,
 92                "stream": s._stream,
 93                "service": s._service.as_json() if s._service is not None else None,
 94            }
 95            for s in sources
 96        ]
 97        filters = {
 98            k: {
 99                "filter": v._func,
100                "args": v._kwargs,
101            }
102            for k, v in filters.items()
103        }
104        arrays = []
105
106        if verbose:
107            print(f"Sending to server. Spec is {len(spec_obj_json_gzip_b64)} bytes")
108
109        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
110        hls_video_url = resp["stream_url"]
111        hls_player_url = resp["player_url"]
112        namespace = resp["namespace"]
113        hls_js_url = server.hls_js_url()
114
115        if method == "link":
116            return hls_video_url
117        if method == "player":
118            return hls_player_url
119        if method == "iframe":
120            from IPython.display import IFrame
121
122            return IFrame(hls_player_url, width=1280, height=720)
123        if method == "html":
124            from IPython.display import HTML
125
126            # We add a namespace to the video element to avoid conflicts with other videos
127            html_code = f"""
128<!DOCTYPE html>
129<html>
130<head>
131    <title>HLS Video Player</title>
132    <!-- Include hls.js library -->
133    <script src="{hls_js_url}"></script>
134</head>
135<body>
136    <!-- Video element -->
137    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
138    <script>
139        var video = document.getElementById('video-{namespace}');
140        var videoSrc = '{hls_video_url}';
141        var hls = new Hls();
142        hls.loadSource(videoSrc);
143        hls.attachMedia(video);
144        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
145            video.play();
146        }});
147    </script>
148</body>
149</html>
150"""
151            return HTML(data=html_code)
152        else:
153            return hls_player_url
154
155    def load(self, server):
156        spec, sources, filters = self._to_json_spec()
157        spec_json_bytes = json.dumps(spec).encode("utf-8")
158        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
159        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
160
161        sources = [
162            {
163                "name": s._name,
164                "path": s._path,
165                "stream": s._stream,
166                "service": s._service.as_json() if s._service is not None else None,
167            }
168            for s in sources
169        ]
170        filters = {
171            k: {
172                "filter": v._func,
173                "args": v._kwargs,
174            }
175            for k, v in filters.items()
176        }
177        arrays = []
178
179        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
180        namespace = resp["namespace"]
181        return Loader(server, namespace, self._domain)
182
183    def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
184        """Save the video to a file."""
185
186        assert encoder is None or type(encoder) == str
187        assert encoder_opts is None or type(encoder_opts) == dict
188        if encoder_opts is not None:
189            for k, v in encoder_opts.items():
190                assert type(k) == str and type(v) == str
191
192        spec, sources, filters = self._to_json_spec()
193        spec_json_bytes = json.dumps(spec).encode("utf-8")
194        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
195        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
196
197        sources = [
198            {
199                "name": s._name,
200                "path": s._path,
201                "stream": s._stream,
202                "service": s._service.as_json() if s._service is not None else None,
203            }
204            for s in sources
205        ]
206        filters = {
207            k: {
208                "filter": v._func,
209                "args": v._kwargs,
210            }
211            for k, v in filters.items()
212        }
213        arrays = []
214
215        resp = server._export(
216            pth,
217            spec_obj_json_gzip_b64,
218            sources,
219            filters,
220            arrays,
221            self._fmt,
222            encoder,
223            encoder_opts,
224            format,
225        )
226
227        return resp
228
229    def _vrod_bench(self, server):
230        out = {}
231        pth = "spec.json"
232        start_t = time.time()
233        with open(pth, "w") as outfile:
234            spec, sources, filters = self._to_json_spec()
235            outfile.write(json.dumps(spec))
236
237        sources = [
238            {
239                "name": s._name,
240                "path": s._path,
241                "stream": s._stream,
242                "service": s._service.as_json() if s._service is not None else None,
243            }
244            for s in sources
245        ]
246        filters = {
247            k: {
248                "filter": v._func,
249                "args": v._kwargs,
250            }
251            for k, v in filters.items()
252        }
253        arrays = []
254        end_t = time.time()
255        out["vrod_create_spec"] = end_t - start_t
256
257        start = time.time()
258        resp = server._new(pth, sources, filters, arrays, self._fmt)
259        end = time.time()
260        out["vrod_register"] = end - start
261
262        stream_url = resp["stream_url"]
263        first_segment = stream_url.replace("stream.m3u8", "segment-0.ts")
264
265        start = time.time()
266        r = requests.get(first_segment)
267        r.raise_for_status()
268        end = time.time()
269        out["vrod_first_segment"] = end - start
270        return out
271
272    def _dve2_bench(self, server):
273        pth = "spec.json"
274        out = {}
275        start_t = time.time()
276        with open(pth, "w") as outfile:
277            spec, sources, filters = self._to_json_spec()
278            outfile.write(json.dumps(spec))
279
280        sources = [
281            {
282                "name": s._name,
283                "path": s._path,
284                "stream": s._stream,
285                "service": s._service.as_json() if s._service is not None else None,
286            }
287            for s in sources
288        ]
289        filters = {
290            k: {
291                "filter": v._func,
292                "args": v._kwargs,
293            }
294            for k, v in filters.items()
295        }
296        arrays = []
297        end_t = time.time()
298        out["dve2_create_spec"] = end_t - start_t
299
300        start = time.time()
301        resp = server._export(pth, sources, filters, arrays, self._fmt, None, None)
302        end = time.time()
303        out["dve2_exec"] = end - start
304        return out
305
306
class Loader:
    """Index-based access to the raw frame bytes of a registered spec.

    ``server`` must provide ``_raw(namespace, start_i, end_i)`` returning the
    concatenated bytes of frames ``start_i`` through ``end_i`` inclusive.
    """

    def __init__(self, server, namespace: str, domain):
        self._server = server
        self._namespace = namespace
        self._domain = domain

    def _chunk(self, start_i, end_i):
        """Fetch the raw bytes for frames ``start_i``..``end_i`` (inclusive)."""
        return self._server._raw(self._namespace, start_i, end_i)

    def __len__(self):
        return len(self._domain)

    def _find_index_by_rational(self, value):
        """Return the position of timestamp ``value`` in the domain.

        Raises ``ValueError`` if the timestamp is not part of the domain.
        """
        try:
            # Single scan instead of a membership test followed by index().
            return self._domain.index(value)
        except ValueError:
            raise ValueError(
                f"Rational timestamp {value} is not in the domain"
            ) from None

    def __getitem__(self, index):
        """Return frame bytes for an integer index or a list for a slice.

        For a slice, the frames in ``[start, stop)`` are fetched in one request
        and split into equally sized per-frame ``bytes`` objects; an empty
        slice returns ``[]`` without contacting the server. The slice step is
        not consulted. Out-of-range indices fail an assertion; any other index
        type raises ``TypeError``.
        """
        total = len(self._domain)
        if isinstance(index, slice):
            start = index.start if index.start is not None else 0
            end = index.stop if index.stop is not None else total
            assert 0 <= start <= end <= total
            num_frames = end - start
            if num_frames == 0:
                # Avoid a modulo/division by zero and a pointless request.
                return []
            all_bytes = self._chunk(start, end - 1)
            assert len(all_bytes) % num_frames == 0
            frame_size = len(all_bytes) // num_frames
            return [
                all_bytes[i * frame_size : (i + 1) * frame_size]
                for i in range(num_frames)
            ]
        elif isinstance(index, int):
            assert 0 <= index < total
            return self._chunk(index, index)
        else:
            raise TypeError(
                "Invalid argument type for iloc. Use a slice or an integer."
            )
352
353
class YrdenServer:
    """A connection to a Yrden server"""

    def __init__(self, domain=None, port=None, bin="vidformer-cli"):
        """Connect to a Yrden server

        When ``port`` is given, connects to the existing server at
        ``domain:port``. Otherwise launches ``bin yrden --port <random port>``
        on localhost and connects to that. Raises ``Exception`` if the server
        never answers; prints a warning if its reported version differs from
        this library's.
        """
        self._domain = domain
        self._port = port
        self._proc = None
        if self._port is None:
            assert bin is not None
            self._domain = "localhost"
            self._port = random.randint(49152, 65535)
            launch = [bin, "yrden", "--port", str(self._port)]
            if _in_notebook:
                # Printing the URL inside a notebook is a trick that gets
                # VS Code to forward the port.
                launch.append("--print-url")
            self._proc = subprocess.Popen(launch)

        reported = _check_hls_link_exists(self._endpoint(""))
        if reported is None:
            raise Exception("Failed to connect to server")

        wanted = f"vidformer-yrden v{__version__}"
        if reported != wanted:
            print(
                f"Warning: Expected version `{wanted}`, got `{reported}`. API may not be compatible!"
            )

    def _endpoint(self, route):
        """Return the absolute URL of ``route`` on this server."""
        return f"http://{self._domain}:{self._port}/{route}"

    def _post_json(self, route, payload):
        """POST ``payload`` as JSON and return the decoded JSON reply.

        Raises ``Exception`` carrying the response text on a non-OK status.
        """
        reply = requests.post(self._endpoint(route), json=payload)
        if not reply.ok:
            raise Exception(reply.text)
        return reply.json()

    def _source(self, name: str, path: str, stream: int, service):
        """Register a source and return its description.

        The ``ts`` entry of the reply is converted to a list of ``Fraction``.
        """
        payload = {
            "name": name,
            "path": path,
            "stream": stream,
            "service": None if service is None else service.as_json(),
        }
        info = self._post_json("source", payload)
        info["ts"] = [Fraction(pair[0], pair[1]) for pair in info["ts"]]
        return info

    def _new(self, spec, sources, filters, arrays, fmt):
        """Register a spec and return the server's JSON reply."""
        payload = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
        }
        return self._post_json("new", payload)

    def _export(
        self, pth, spec, sources, filters, arrays, fmt, encoder, encoder_opts, format
    ):
        """Ask the server to export a spec to ``pth``; return its JSON reply."""
        payload = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
            "output_path": pth,
            "encoder": encoder,
            "encoder_opts": encoder_opts,
            "format": format,
        }
        return self._post_json("export", payload)

    def _raw(self, namespace, start_i, end_i):
        """Return the body of ``GET /<namespace>/raw/<start_i>-<end_i>``.

        Raises ``Exception`` carrying the response text on a non-OK status.
        """
        reply = requests.get(self._endpoint(f"{namespace}/raw/{start_i}-{end_i}"))
        if not reply.ok:
            raise Exception(reply.text)
        return reply.content

    def hls_js_url(self):
        """Return the link to the yrden-hosted hls.js file"""
        return self._endpoint("hls.js")

    def __del__(self):
        # Only kill a server process this object launched itself.
        proc = self._proc
        if proc is not None:
            proc.kill()
459
460
class SourceExpr:
    """A reference to one frame of a source.

    The frame is addressed either by integer position (``is_iloc`` true) or
    by a rational timestamp.
    """

    def __init__(self, source, idx, is_iloc):
        self._source = source
        self._idx = idx
        self._is_iloc = is_iloc

    def __repr__(self):
        accessor = ".iloc" if self._is_iloc else ""
        return f"{self._source._name}{accessor}[{self._idx}]"

    def _to_json_spec(self):
        """Return the JSON-serializable form of this frame reference."""
        if self._is_iloc:
            where = {"ILoc": int(self._idx)}
        else:
            where = {"T": [self._idx.numerator, self._idx.denominator]}
        return {"Source": {"video": self._source._name, "index": where}}

    def _sources(self):
        """Return the single-element set holding the referenced source."""
        return {self._source}

    def _filters(self):
        """A bare source reference uses no filters."""
        return {}
494
495
class SourceILoc:
    """Positional indexer for a source: ``source.iloc[i]``."""

    def __init__(self, source):
        self._source = source

    def __getitem__(self, idx):
        """Return the frame expression at integer position ``idx``.

        Raises ``Exception`` unless ``idx`` is exactly an ``int``.
        """
        if type(idx) is int:
            return SourceExpr(self._source, idx, True)
        raise Exception("Source iloc index must be an integer")
504
505
class Source:
    """A video source registered with a server.

    Construction immediately registers the source through
    ``server._source`` and keeps the returned description.
    """

    def __init__(
        self, server: YrdenServer, name: str, path: str, stream: int, service=None
    ):
        self._server = server
        self._name = name
        self._path = path
        self._stream = stream
        self._service = service

        self.iloc = SourceILoc(self)

        self._src = server._source(name, path, stream, service)

    def fmt(self):
        """Return the source's ``width``, ``height`` and ``pix_fmt`` as a dict."""
        return {key: self._src[key] for key in ("width", "height", "pix_fmt")}

    def ts(self):
        """Return the timestamps reported by the server for this source."""
        return self._src["ts"]

    def __len__(self):
        return len(self.ts())

    def __getitem__(self, idx):
        """Return the frame expression at timestamp ``idx``.

        Raises ``Exception`` unless ``idx`` is exactly a ``Fraction``.
        """
        if type(idx) is Fraction:
            return SourceExpr(self, idx, False)
        raise Exception("Source index must be a Fraction")

    def play(self, *args, **kwargs):
        """Play the video live in the notebook."""

        def passthrough(t, i):
            # Each output frame is simply the source frame at the same time.
            return self[t]

        timestamps = self.ts()
        return Spec(timestamps, passthrough, self.fmt()).play(*args, **kwargs)
547
548
class StorageService:
    """A named storage backend plus its string-valued configuration."""

    def __init__(self, service: str, **kwargs):
        """Store the service name and config.

        Raises ``Exception`` if the name or any config value is not a ``str``.
        """
        if type(service) is not str:
            raise Exception("Service name must be a string")
        self._service = service
        offending = next(
            (key for key, value in kwargs.items() if type(value) is not str), None
        )
        if offending is not None:
            raise Exception(f"Value of {offending} must be a string")
        self._config = kwargs

    def as_json(self):
        """Return the JSON-serializable description of this service."""
        description = {"service": self._service}
        description["config"] = self._config
        return description

    def __repr__(self):
        return "{}(config={})".format(self._service, self._config)
564
565
def _json_arg(arg):
    """Wrap a filter argument in its tagged JSON form.

    Frame expressions become ``{"Frame": ...}``; exact ``int``, ``str`` and
    ``bool`` values become ``{"Data": {"Int" | "String" | "Bool": value}}``.
    Any other type fails an assertion.
    """
    kind = type(arg)
    if kind == FilterExpr or kind == SourceExpr:
        return {"Frame": arg._to_json_spec()}
    # Exact type comparison keeps bool (a subclass of int) tagged as "Bool".
    for py_type, tag in ((int, "Int"), (str, "String"), (bool, "Bool")):
        if kind == py_type:
            return {"Data": {tag: arg}}
    assert False
577
578
class Filter:
    """A named filter that can be applied to frames by calling it."""

    def __init__(self, name: str, tl_func=None, **kwargs):
        """Create a filter.

        ``name`` is the display name. ``tl_func`` is the top-level function,
        i.e. the true implementation; it defaults to ``name``. Keyword
        arguments configure the filter infrastructure (they are not invocation
        arguments) and must all be ``str``, otherwise ``Exception`` is raised.
        """
        self._name = name
        self._func = name if tl_func is None else tl_func

        for key, value in kwargs.items():
            if type(value) is not str:
                raise Exception(f"Value of {key} must be a string")
        self._kwargs = kwargs

    def __call__(self, *args, **kwargs):
        """Return the expression applying this filter to the given arguments."""
        return FilterExpr(self, args, kwargs)
597
598
class FilterExpr:
    """The application of a filter to positional and keyword arguments."""

    def __init__(self, filter: Filter, args, kwargs):
        self._filter = filter
        self._args = args
        self._kwargs = kwargs

    def __repr__(self):
        def show(value):
            # Strings are shown double-quoted; everything else via str().
            return f'"{value}"' if type(value) == str else str(value)

        rendered = [show(arg) for arg in self._args]
        rendered.extend(f"{k}={show(v)}" for k, v in self._kwargs.items())
        return f"{self._filter._name}({', '.join(rendered)})"

    def _to_json_spec(self):
        """Return the JSON-serializable form of this filter application."""
        positional = [_json_arg(arg) for arg in self._args]
        keyword = {k: _json_arg(v) for k, v in self._kwargs.items()}
        return {
            "Filter": {
                "name": self._filter._name,
                "args": positional,
                "kwargs": keyword,
            }
        }

    def _sources(self):
        """Return the set of sources referenced by any frame argument."""
        found = set()
        for value in (*self._args, *self._kwargs.values()):
            if type(value) == FilterExpr or type(value) == SourceExpr:
                found = found.union(value._sources())
        return found

    def _filters(self):
        """Return a name -> filter mapping for this and all nested filters."""
        collected = {self._filter._name: self._filter}
        for value in (*self._args, *self._kwargs.values()):
            if type(value) == FilterExpr:
                collected.update(value._filters())
        return collected
643
644
class UDF:
    """User-defined filter superclass

    Subclasses implement ``filter`` and ``filter_type``. ``into_filter`` starts
    a child process that serves the UDF on a Unix socket; every request and
    reply on that socket is a 4-byte big-endian length followed by a msgpack
    body.
    """

    # Python type expected for each scalar tag on the wire.
    _SCALAR_TYPES = {"String": str, "Int": int, "Bool": bool}

    def __init__(self, name: str):
        self._name = name
        self._socket_path = None
        self._p = None

    def filter(self, *args, **kwargs):
        """Compute the output frame; must be overridden and return a UDFFrame."""
        raise Exception("User must implement the filter method")

    def filter_type(self, *args, **kwargs):
        """Compute the output frame type; must be overridden and return a UDFFrameType."""
        raise Exception("User must implement the filter_type method")

    def into_filter(self):
        """Start the host process and return a ``Filter`` that talks to it.

        May only be called once per instance (asserted).
        """
        assert self._socket_path is None
        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
        self._p = multiprocessing.Process(
            target=_run_udf_host, args=(self, self._socket_path)
        )
        self._p.start()
        return Filter(
            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
        )

    @staticmethod
    def _recv_all(connection, size):
        """Read up to ``size`` bytes, looping over short reads.

        Returns fewer than ``size`` bytes only if the peer closed the
        connection first.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _dispatch(self, obj):
        """Run one decoded request and return the serializable response.

        ``obj`` must contain ``op`` (``"filter"`` or ``"filter_type"``),
        ``args`` (a list) and ``kwargs`` (a mapping) of tagged values. Raises
        ``Exception`` for an unknown op or when the user implementation
        returns the wrong type or a non-rgb24 frame.
        """
        f_op, f_args, f_kwargs = obj["op"], obj["args"], obj["kwargs"]

        if f_op == "filter":
            args = [self._deser_filter(x) for x in f_args]
            # .items() is required: iterating the mapping yields only keys.
            kwargs = {k: self._deser_filter(v) for k, v in f_kwargs.items()}
            response = self.filter(*args, **kwargs)
            if type(response) != UDFFrame:
                raise Exception(
                    f"filter must return a UDFFrame, got {type(response)}"
                )
            if response.frame_type().pix_fmt() != "rgb24":
                raise Exception(
                    f"filter must return a frame with pix_fmt 'rgb24', got {response.frame_type().pix_fmt()}"
                )
            return response._response_ser()

        if f_op == "filter_type":
            args = [self._deser_filter_type(x) for x in f_args]
            kwargs = {k: self._deser_filter_type(v) for k, v in f_kwargs.items()}
            response = self.filter_type(*args, **kwargs)
            if type(response) != UDFFrameType:
                raise Exception(
                    f"filter_type must return a UDFFrameType, got {type(response)}"
                )
            if response.pix_fmt() != "rgb24":
                raise Exception(
                    f"filter_type must return a frame with pix_fmt 'rgb24', got {response.pix_fmt()}"
                )
            return response._response_ser()

        raise Exception(f"Unknown operation: {f_op}")

    def _handle_connection(self, connection):
        """Serve requests on ``connection`` until the peer disconnects.

        The connection is always closed on exit. Raises ``Exception`` if the
        peer disconnects in the middle of a message body.
        """
        try:
            while True:
                header = self._recv_all(connection, 4)
                if len(header) != 4:
                    break
                frame_len = int.from_bytes(header, byteorder="big")
                data = self._recv_all(connection, frame_len)
                if not data:
                    break
                if len(data) < frame_len:
                    raise Exception("Partial data received")

                obj = msgpack.unpackb(data, raw=False)
                response = msgpack.packb(self._dispatch(obj), use_bin_type=True)
                connection.sendall(len(response).to_bytes(4, byteorder="big"))
                connection.sendall(response)
        finally:
            connection.close()

    @staticmethod
    def _unwrap_tagged(obj):
        """Split a single-key tagged value into ``(tag, payload)`` (asserted)."""
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]
        return type_key, obj[type_key]

    def _deser_filter_type(self, obj):
        """Decode one ``filter_type`` argument.

        Frames become ``UDFFrameType``; scalars are returned as-is after a
        type check. Malformed input fails an assertion.
        """
        type_key, value = self._unwrap_tagged(obj)
        if type_key != "Frame":
            assert type(value) == self._SCALAR_TYPES[type_key]
            return value

        frame = value
        assert type(frame) == dict
        assert "width" in frame
        assert "height" in frame
        assert "format" in frame
        assert type(frame["width"]) == int
        assert type(frame["height"]) == int
        assert frame["format"] == 2  # AV_PIX_FMT_RGB24
        return UDFFrameType(frame["width"], frame["height"], "rgb24")

    def _deser_filter(self, obj):
        """Decode one ``filter`` argument.

        Frames become ``UDFFrame`` objects wrapping a ``(height, width, 3)``
        uint8 array built from the received bytes; scalars are returned as-is
        after a type check. Malformed input fails an assertion.
        """
        type_key, value = self._unwrap_tagged(obj)
        if type_key != "Frame":
            assert type(value) == self._SCALAR_TYPES[type_key]
            return value

        frame = value
        assert type(frame) == dict
        assert "data" in frame
        assert "width" in frame
        assert "height" in frame
        assert "format" in frame
        assert type(frame["width"]) == int
        assert type(frame["height"]) == int
        assert frame["format"] == "rgb24"
        assert type(frame["data"]) == bytes

        data = np.frombuffer(frame["data"], dtype=np.uint8)
        data = data.reshape(frame["height"], frame["width"], 3)
        return UDFFrame(data, UDFFrameType(frame["width"], frame["height"], "rgb24"))

    def _host(self, socket_path: str):
        """Listen on ``socket_path`` forever, one thread per connection."""
        if os.path.exists(socket_path):
            os.remove(socket_path)

        # start listening on the socket; `with` closes it if the loop ever exits
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
            sock.bind(socket_path)
            sock.listen(1)

            while True:
                # accept incoming connection
                connection, _ = sock.accept()
                thread = threading.Thread(
                    target=self._handle_connection, args=(connection,)
                )
                thread.start()

    def __del__(self):
        # getattr: a subclass may have skipped UDF.__init__.
        socket_path = getattr(self, "_socket_path", None)
        if socket_path is not None:
            self._p.terminate()
            try:
                os.remove(socket_path)
            except FileNotFoundError:
                # it's possible the process hasn't even created the socket yet
                pass
821
822
class UDFFrameType:
    """The dimensions and pixel format of a frame handled by a UDF."""

    def __init__(self, width: int, height: int, pix_fmt: str):
        # Exact type checks, in the same order as the parameters.
        for value, expected in ((width, int), (height, int), (pix_fmt, str)):
            assert type(value) == expected

        self._width = width
        self._height = height
        self._pix_fmt = pix_fmt

    def width(self):
        """Return the frame width."""
        return self._width

    def height(self):
        """Return the frame height."""
        return self._height

    def pix_fmt(self):
        """Return the pixel format name."""
        return self._pix_fmt

    def _response_ser(self):
        """Return the serializable ``filter_type`` response for this type."""
        described = {
            "width": self._width,
            "height": self._height,
            "format": 2,  # AV_PIX_FMT_RGB24
        }
        return {"frame_type": described}

    def __repr__(self):
        size = f"{self._width}x{self._height}"
        return f"FrameType<{size}, {self._pix_fmt}>"
853
854
class UDFFrame:
    """A frame's pixel data together with its ``UDFFrameType``."""

    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
        assert type(data) == np.ndarray
        assert type(f_type) == UDFFrameType

        # We only support RGB24 for now
        assert data.dtype == np.uint8
        assert data.shape[2] == 3

        # The array must be laid out as (height, width, channels) and agree
        # with the declared frame type.
        assert data.shape[0] == f_type.height()
        assert data.shape[1] == f_type.width()
        assert f_type.pix_fmt() == "rgb24"

        self._data = data
        self._f_type = f_type

    def data(self):
        """Return the underlying pixel array."""
        return self._data

    def frame_type(self):
        """Return the frame's ``UDFFrameType``."""
        return self._f_type

    def _response_ser(self):
        """Return the serializable ``filter`` response for this frame."""
        kind = self._f_type
        payload = {
            "data": self._data.tobytes(),
            "width": kind.width(),
            "height": kind.height(),
            "format": "rgb24",
        }
        return {"frame": payload}

    def __repr__(self):
        kind = self._f_type
        return f"Frame<{kind.width()}x{kind.height()}, {kind.pix_fmt()}>"
890
891
def _run_udf_host(udf: UDF, socket_path: str):
    """Serve ``udf`` on ``socket_path``.

    Used as the target of the host process started by ``UDF.into_filter``.
    """
    serve = udf._host
    serve(socket_path)
class Spec:
 47class Spec:
 48    def __init__(self, domain: list[Fraction], render, fmt: dict):
 49        self._domain = domain
 50        self._render = render
 51        self._fmt = fmt
 52
 53    def __repr__(self):
 54        lines = []
 55        for i, t in enumerate(self._domain):
 56            frame_expr = self._render(t, i)
 57            lines.append(
 58                f"{t.numerator}/{t.denominator} => {frame_expr}",
 59            )
 60        return "\n".join(lines)
 61
 62    def _sources(self):
 63        s = set()
 64        for i, t in enumerate(self._domain):
 65            frame_expr = self._render(t, i)
 66            s = s.union(frame_expr._sources())
 67        return s
 68
 69    def _to_json_spec(self):
 70        frames = []
 71        s = set()
 72        f = {}
 73        for i, t in enumerate(self._domain):
 74            frame_expr = self._render(t, i)
 75            s = s.union(frame_expr._sources())
 76            f = {**f, **frame_expr._filters()}
 77            frame = [[t.numerator, t.denominator], frame_expr._to_json_spec()]
 78            frames.append(frame)
 79        return {"frames": frames}, s, f
 80
 81    def play(self, server, method="html", verbose=False):
 82        """Play the video live in the notebook."""
 83
 84        spec, sources, filters = self._to_json_spec()
 85        spec_json_bytes = json.dumps(spec).encode("utf-8")
 86        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
 87        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
 88
 89        sources = [
 90            {
 91                "name": s._name,
 92                "path": s._path,
 93                "stream": s._stream,
 94                "service": s._service.as_json() if s._service is not None else None,
 95            }
 96            for s in sources
 97        ]
 98        filters = {
 99            k: {
100                "filter": v._func,
101                "args": v._kwargs,
102            }
103            for k, v in filters.items()
104        }
105        arrays = []
106
107        if verbose:
108            print(f"Sending to server. Spec is {len(spec_obj_json_gzip_b64)} bytes")
109
110        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
111        hls_video_url = resp["stream_url"]
112        hls_player_url = resp["player_url"]
113        namespace = resp["namespace"]
114        hls_js_url = server.hls_js_url()
115
116        if method == "link":
117            return hls_video_url
118        if method == "player":
119            return hls_player_url
120        if method == "iframe":
121            from IPython.display import IFrame
122
123            return IFrame(hls_player_url, width=1280, height=720)
124        if method == "html":
125            from IPython.display import HTML
126
127            # We add a namespace to the video element to avoid conflicts with other videos
128            html_code = f"""
129<!DOCTYPE html>
130<html>
131<head>
132    <title>HLS Video Player</title>
133    <!-- Include hls.js library -->
134    <script src="{hls_js_url}"></script>
135</head>
136<body>
137    <!-- Video element -->
138    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
139    <script>
140        var video = document.getElementById('video-{namespace}');
141        var videoSrc = '{hls_video_url}';
142        var hls = new Hls();
143        hls.loadSource(videoSrc);
144        hls.attachMedia(video);
145        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
146            video.play();
147        }});
148    </script>
149</body>
150</html>
151"""
152            return HTML(data=html_code)
153        else:
154            return hls_player_url
155
156    def load(self, server):
157        spec, sources, filters = self._to_json_spec()
158        spec_json_bytes = json.dumps(spec).encode("utf-8")
159        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
160        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
161
162        sources = [
163            {
164                "name": s._name,
165                "path": s._path,
166                "stream": s._stream,
167                "service": s._service.as_json() if s._service is not None else None,
168            }
169            for s in sources
170        ]
171        filters = {
172            k: {
173                "filter": v._func,
174                "args": v._kwargs,
175            }
176            for k, v in filters.items()
177        }
178        arrays = []
179
180        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
181        namespace = resp["namespace"]
182        return Loader(server, namespace, self._domain)
183
184    def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
185        """Save the video to a file."""
186
187        assert encoder is None or type(encoder) == str
188        assert encoder_opts is None or type(encoder_opts) == dict
189        if encoder_opts is not None:
190            for k, v in encoder_opts.items():
191                assert type(k) == str and type(v) == str
192
193        spec, sources, filters = self._to_json_spec()
194        spec_json_bytes = json.dumps(spec).encode("utf-8")
195        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
196        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
197
198        sources = [
199            {
200                "name": s._name,
201                "path": s._path,
202                "stream": s._stream,
203                "service": s._service.as_json() if s._service is not None else None,
204            }
205            for s in sources
206        ]
207        filters = {
208            k: {
209                "filter": v._func,
210                "args": v._kwargs,
211            }
212            for k, v in filters.items()
213        }
214        arrays = []
215
216        resp = server._export(
217            pth,
218            spec_obj_json_gzip_b64,
219            sources,
220            filters,
221            arrays,
222            self._fmt,
223            encoder,
224            encoder_opts,
225            format,
226        )
227
228        return resp
229
230    def _vrod_bench(self, server):
231        out = {}
232        pth = "spec.json"
233        start_t = time.time()
234        with open(pth, "w") as outfile:
235            spec, sources, filters = self._to_json_spec()
236            outfile.write(json.dumps(spec))
237
238        sources = [
239            {
240                "name": s._name,
241                "path": s._path,
242                "stream": s._stream,
243                "service": s._service.as_json() if s._service is not None else None,
244            }
245            for s in sources
246        ]
247        filters = {
248            k: {
249                "filter": v._func,
250                "args": v._kwargs,
251            }
252            for k, v in filters.items()
253        }
254        arrays = []
255        end_t = time.time()
256        out["vrod_create_spec"] = end_t - start_t
257
258        start = time.time()
259        resp = server._new(pth, sources, filters, arrays, self._fmt)
260        end = time.time()
261        out["vrod_register"] = end - start
262
263        stream_url = resp["stream_url"]
264        first_segment = stream_url.replace("stream.m3u8", "segment-0.ts")
265
266        start = time.time()
267        r = requests.get(first_segment)
268        r.raise_for_status()
269        end = time.time()
270        out["vrod_first_segment"] = end - start
271        return out
272
273    def _dve2_bench(self, server):
274        pth = "spec.json"
275        out = {}
276        start_t = time.time()
277        with open(pth, "w") as outfile:
278            spec, sources, filters = self._to_json_spec()
279            outfile.write(json.dumps(spec))
280
281        sources = [
282            {
283                "name": s._name,
284                "path": s._path,
285                "stream": s._stream,
286                "service": s._service.as_json() if s._service is not None else None,
287            }
288            for s in sources
289        ]
290        filters = {
291            k: {
292                "filter": v._func,
293                "args": v._kwargs,
294            }
295            for k, v in filters.items()
296        }
297        arrays = []
298        end_t = time.time()
299        out["dve2_create_spec"] = end_t - start_t
300
301        start = time.time()
302        resp = server._export(pth, sources, filters, arrays, self._fmt, None, None)
303        end = time.time()
304        out["dve2_exec"] = end - start
305        return out
Spec(domain: list[fractions.Fraction], render, fmt: dict)
    def __init__(self, domain: list[Fraction], render, fmt: dict):
        """Store the timestamp domain, the render callable and the format dict."""
        self._domain = domain
        self._render = render
        self._fmt = fmt
def play(self, server, method='html', verbose=False):
 81    def play(self, server, method="html", verbose=False):
 82        """Play the video live in the notebook."""
 83
 84        spec, sources, filters = self._to_json_spec()
 85        spec_json_bytes = json.dumps(spec).encode("utf-8")
 86        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
 87        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
 88
 89        sources = [
 90            {
 91                "name": s._name,
 92                "path": s._path,
 93                "stream": s._stream,
 94                "service": s._service.as_json() if s._service is not None else None,
 95            }
 96            for s in sources
 97        ]
 98        filters = {
 99            k: {
100                "filter": v._func,
101                "args": v._kwargs,
102            }
103            for k, v in filters.items()
104        }
105        arrays = []
106
107        if verbose:
108            print(f"Sending to server. Spec is {len(spec_obj_json_gzip_b64)} bytes")
109
110        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
111        hls_video_url = resp["stream_url"]
112        hls_player_url = resp["player_url"]
113        namespace = resp["namespace"]
114        hls_js_url = server.hls_js_url()
115
116        if method == "link":
117            return hls_video_url
118        if method == "player":
119            return hls_player_url
120        if method == "iframe":
121            from IPython.display import IFrame
122
123            return IFrame(hls_player_url, width=1280, height=720)
124        if method == "html":
125            from IPython.display import HTML
126
127            # We add a namespace to the video element to avoid conflicts with other videos
128            html_code = f"""
129<!DOCTYPE html>
130<html>
131<head>
132    <title>HLS Video Player</title>
133    <!-- Include hls.js library -->
134    <script src="{hls_js_url}"></script>
135</head>
136<body>
137    <!-- Video element -->
138    <video id="video-{namespace}" controls width="640" height="360" autoplay></video>
139    <script>
140        var video = document.getElementById('video-{namespace}');
141        var videoSrc = '{hls_video_url}';
142        var hls = new Hls();
143        hls.loadSource(videoSrc);
144        hls.attachMedia(video);
145        hls.on(Hls.Events.MANIFEST_PARSED, function() {{
146            video.play();
147        }});
148    </script>
149</body>
150</html>
151"""
152            return HTML(data=html_code)
153        else:
154            return hls_player_url

Play the video live in the notebook.

def load(self, server):
156    def load(self, server):
157        spec, sources, filters = self._to_json_spec()
158        spec_json_bytes = json.dumps(spec).encode("utf-8")
159        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
160        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
161
162        sources = [
163            {
164                "name": s._name,
165                "path": s._path,
166                "stream": s._stream,
167                "service": s._service.as_json() if s._service is not None else None,
168            }
169            for s in sources
170        ]
171        filters = {
172            k: {
173                "filter": v._func,
174                "args": v._kwargs,
175            }
176            for k, v in filters.items()
177        }
178        arrays = []
179
180        resp = server._new(spec_obj_json_gzip_b64, sources, filters, arrays, self._fmt)
181        namespace = resp["namespace"]
182        return Loader(server, namespace, self._domain)
def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
184    def save(self, server, pth, encoder=None, encoder_opts=None, format=None):
185        """Save the video to a file."""
186
187        assert encoder is None or type(encoder) == str
188        assert encoder_opts is None or type(encoder_opts) == dict
189        if encoder_opts is not None:
190            for k, v in encoder_opts.items():
191                assert type(k) == str and type(v) == str
192
193        spec, sources, filters = self._to_json_spec()
194        spec_json_bytes = json.dumps(spec).encode("utf-8")
195        spec_obj_json_gzip = gzip.compress(spec_json_bytes, compresslevel=1)
196        spec_obj_json_gzip_b64 = base64.b64encode(spec_obj_json_gzip).decode("utf-8")
197
198        sources = [
199            {
200                "name": s._name,
201                "path": s._path,
202                "stream": s._stream,
203                "service": s._service.as_json() if s._service is not None else None,
204            }
205            for s in sources
206        ]
207        filters = {
208            k: {
209                "filter": v._func,
210                "args": v._kwargs,
211            }
212            for k, v in filters.items()
213        }
214        arrays = []
215
216        resp = server._export(
217            pth,
218            spec_obj_json_gzip_b64,
219            sources,
220            filters,
221            arrays,
222            self._fmt,
223            encoder,
224            encoder_opts,
225            format,
226        )
227
228        return resp

Save the video to a file.

class Loader:
class Loader:
    """Index-based access to the raw frames of a spec registered on a server.

    Frames are addressed by their integer position in ``domain``. Indexing
    returns the raw bytes the server sends back for the requested frames.
    """

    def __init__(self, server, namespace: str, domain):
        self._server = server
        self._namespace = namespace
        self._domain = domain

    def _chunk(self, start_i, end_i):
        """Fetch the raw bytes for frames ``start_i`` through ``end_i`` inclusive."""
        return self._server._raw(self._namespace, start_i, end_i)

    def __len__(self):
        return len(self._domain)

    def _find_index_by_rational(self, value):
        """Return the position of timestamp ``value`` in the domain.

        Raises:
            ValueError: If ``value`` is not in the domain.
        """
        try:
            # A single scan; ``index`` itself reports a missing value.
            return self._domain.index(value)
        except ValueError:
            raise ValueError(
                f"Rational timestamp {value} is not in the domain"
            ) from None

    def __getitem__(self, index):
        """Return raw frame bytes for an integer index or a slice.

        An ``int`` returns the bytes of that frame. A slice returns a list
        with one equal-sized ``bytes`` chunk per frame. The slice step is
        ignored, and negative bounds are rejected by assertion. An empty
        slice returns ``[]`` without contacting the server.

        Raises:
            TypeError: If ``index`` is neither a slice nor an ``int``.
        """
        if isinstance(index, slice):
            n_total = len(self._domain)
            start = index.start if index.start is not None else 0
            end = index.stop if index.stop is not None else n_total
            assert start >= 0 and start < n_total
            assert end >= 0 and end <= n_total
            assert start <= end
            num_frames = end - start
            if num_frames == 0:
                # Previously this requested the range start..start-1 and then
                # divided by zero; an empty slice simply has no frames.
                return []
            all_bytes = self._chunk(start, end - 1)
            all_bytes_len = len(all_bytes)
            assert all_bytes_len % num_frames == 0
            frame_size = all_bytes_len // num_frames
            return [
                all_bytes[i * frame_size : (i + 1) * frame_size]
                for i in range(num_frames)
            ]
        elif isinstance(index, int):
            assert index >= 0 and index < len(self._domain)
            return self._chunk(index, index)
        else:
            raise TypeError(
                "Invalid argument type for iloc. Use a slice or an integer."
            )
Loader(server, namespace: str, domain)
    def __init__(self, server, namespace: str, domain):
        """Store the server, the namespace of the registered spec and its domain."""
        self._server = server
        self._namespace = namespace
        self._domain = domain
class YrdenServer:
class YrdenServer:
    """A connection to a Yrden server"""

    def __init__(self, domain=None, port=None, bin="vidformer-cli"):
        """Connect to a Yrden server

        When ``port`` is given, connect to the existing server at
        ``domain:port``. When ``port`` is ``None``, start a new server from
        the ``bin`` binary on ``localhost`` at a random port in 49152-65535
        (``domain`` is then ignored).

        Raises:
            Exception: If the server does not answer the version check.
        """

        self._domain = domain
        self._port = port
        self._proc = None
        if self._port is None:
            assert bin is not None
            self._domain = "localhost"
            self._port = random.randint(49152, 65535)
            cmd = [bin, "yrden", "--port", str(self._port)]
            if _in_notebook:
                # We need to print the URL in the notebook
                # This is a trick to get VS Code to forward the port
                cmd += ["--print-url"]
            self._proc = subprocess.Popen(cmd)

        version = _check_hls_link_exists(self._url(""))
        if version is None:
            raise Exception("Failed to connect to server")

        expected_version = f"vidformer-yrden v{__version__}"
        if version != expected_version:
            print(
                f"Warning: Expected version `{expected_version}`, got `{version}`. API may not be compatible!"
            )

    def _url(self, path):
        """Return the HTTP URL for ``path`` (given without a leading slash)."""
        return f"http://{self._domain}:{self._port}/{path}"

    def _post_json(self, endpoint, payload):
        """POST ``payload`` as JSON to ``endpoint`` and return the decoded reply.

        Raises:
            Exception: With the response text, if the status is not OK.
        """
        r = requests.post(self._url(endpoint), json=payload)
        if not r.ok:
            raise Exception(r.text)
        return r.json()

    def _source(self, name: str, path: str, stream: int, service):
        """Register a source and return the reply with ``ts`` as Fractions."""
        resp = self._post_json(
            "source",
            {
                "name": name,
                "path": path,
                "stream": stream,
                "service": service.as_json() if service is not None else None,
            },
        )
        resp["ts"] = [Fraction(x[0], x[1]) for x in resp["ts"]]
        return resp

    def _new(self, spec, sources, filters, arrays, fmt):
        """Register a spec via the ``/new`` endpoint and return the reply."""
        req = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
        }
        return self._post_json("new", req)

    def _export(
        self, pth, spec, sources, filters, arrays, fmt, encoder, encoder_opts, format
    ):
        """Request an export to ``pth`` via ``/export`` and return the reply."""
        req = {
            "spec": spec,
            "sources": sources,
            "filters": filters,
            "arrays": arrays,
            "width": fmt["width"],
            "height": fmt["height"],
            "pix_fmt": fmt["pix_fmt"],
            "output_path": pth,
            "encoder": encoder,
            "encoder_opts": encoder_opts,
            "format": format,
        }
        return self._post_json("export", req)

    def _raw(self, namespace, start_i, end_i):
        """Return the raw response body for frames ``start_i``-``end_i``."""
        r = requests.get(self._url(f"{namespace}/raw/{start_i}-{end_i}"))
        if not r.ok:
            raise Exception(r.text)
        return r.content

    def hls_js_url(self):
        """Return the link to the yrden-hosted hls.js file"""
        return self._url("hls.js")

    def __del__(self):
        # getattr keeps the finalizer quiet on an instance whose __init__
        # never ran, where ``_proc`` does not exist.
        proc = getattr(self, "_proc", None)
        if proc is not None:
            proc.kill()

A connection to a Yrden server

YrdenServer(domain=None, port=None, bin='vidformer-cli')
358    def __init__(self, domain=None, port=None, bin="vidformer-cli"):
359        """Connect to a Yrden server
360
361        Can either connect to an existing server, if domain and port are provided, or start a new server using the provided binary
362        """
363
364        self._domain = domain
365        self._port = port
366        self._proc = None
367        if self._port is None:
368            assert bin is not None
369            self._domain = "localhost"
370            self._port = random.randint(49152, 65535)
371            cmd = [bin, "yrden", "--port", str(self._port)]
372            if _in_notebook:
373                # We need to print the URL in the notebook
374                # This is a trick to get VS Code to forward the port
375                cmd += ["--print-url"]
376            self._proc = subprocess.Popen(cmd)
377
378        version = _check_hls_link_exists(f"http://{self._domain}:{self._port}/")
379        if version is None:
380            raise Exception("Failed to connect to server")
381
382        expected_version = f"vidformer-yrden v{__version__}"
383        if version != expected_version:
384            print(
385                f"Warning: Expected version `{expected_version}`, got `{version}`. API may not be compatible!"
386            )

Connect to a Yrden server

Connects to an existing server when a port is provided; otherwise starts a new local server using the provided binary.

def hls_js_url(self):
    def hls_js_url(self) -> str:
        """Return the link to the yrden-hosted hls.js file.

        The URL has the form ``http://<domain>:<port>/hls.js``.
        """
        return f"http://{self._domain}:{self._port}/hls.js"

Return the link to the yrden-hosted hls.js file

class SourceExpr:
class SourceExpr:
    """A reference to one frame of a source.

    The frame is addressed by integer position when ``is_iloc`` is true, and
    by timestamp otherwise. The timestamp must expose ``numerator`` and
    ``denominator``.
    """

    def __init__(self, source, idx, is_iloc):
        self._source, self._idx, self._is_iloc = source, idx, is_iloc

    def __repr__(self):
        if not self._is_iloc:
            return f"{self._source._name}[{self._idx}]"
        return f"{self._source._name}.iloc[{self._idx}]"

    def _to_json_spec(self):
        """Return the JSON-serializable ``Source`` node for this expression."""
        if self._is_iloc:
            index = {"ILoc": int(self._idx)}
        else:
            index = {"T": [self._idx.numerator, self._idx.denominator]}
        return {"Source": {"video": self._source._name, "index": index}}

    def _sources(self):
        """Return a one-element set holding the referenced source."""
        return {self._source}

    def _filters(self):
        """Return an empty dict: a bare source reference uses no filters."""
        return {}
SourceExpr(source, idx, is_iloc)
    def __init__(self, source, idx, is_iloc):
        """Store the source, the index and whether the index is positional."""
        self._source = source
        self._idx = idx
        self._is_iloc = is_iloc
class SourceILoc:
class SourceILoc:
    """Positional indexer for a source: ``iloc[i]`` yields a ``SourceExpr``."""

    def __init__(self, source):
        self._source = source

    def __getitem__(self, idx):
        """Return the positional frame expression for ``idx`` (exactly ``int``)."""
        if type(idx) is int:
            return SourceExpr(self._source, idx, True)
        raise Exception("Source iloc index must be an integer")
SourceILoc(source)
    def __init__(self, source):
        """Store the source this indexer builds expressions for."""
        self._source = source
class Source:
class Source:
    """A video stream registered with a Yrden server.

    Construction registers the source through ``server._source`` and keeps
    the reply. ``source[t]`` (``t`` a ``Fraction``) and ``source.iloc[i]``
    produce frame expressions.
    """

    def __init__(
        self, server: YrdenServer, name: str, path: str, stream: int, service=None
    ):
        self._server, self._name, self._path = server, name, path
        self._stream, self._service = stream, service

        self.iloc = SourceILoc(self)

        self._src = self._server._source(
            self._name, self._path, self._stream, self._service
        )

    def fmt(self):
        """Return the ``width``/``height``/``pix_fmt`` reported for the source."""
        return {key: self._src[key] for key in ("width", "height", "pix_fmt")}

    def ts(self):
        """Return the source's timestamps as reported by the server."""
        return self._src["ts"]

    def __len__(self):
        return len(self._src["ts"])

    def __getitem__(self, idx):
        """Return the frame expression at timestamp ``idx`` (exactly ``Fraction``)."""
        if type(idx) is Fraction:
            return SourceExpr(self, idx, False)
        raise Exception("Source index must be a Fraction")

    def play(self, *args, **kwargs):
        """Play the video live in the notebook."""

        def render(t, i):
            return self[t]

        domain = self.ts()
        return Spec(domain, render, self.fmt()).play(*args, **kwargs)
Source( server: YrdenServer, name: str, path: str, stream: int, service=None)
    def __init__(
        self, server: YrdenServer, name: str, path: str, stream: int, service=None
    ):
        """Store the source description and register it with ``server``.

        The reply of ``server._source`` is kept in ``self._src``.
        """
        self._server = server
        self._name = name
        self._path = path
        self._stream = stream
        self._service = service

        # Positional indexer: source.iloc[i]
        self.iloc = SourceILoc(self)

        self._src = self._server._source(
            self._name, self._path, self._stream, self._service
        )
iloc
def fmt(self):
523    def fmt(self):
524        return {
525            "width": self._src["width"],
526            "height": self._src["height"],
527            "pix_fmt": self._src["pix_fmt"],
528        }
def ts(self):
    def ts(self):
        """Return the ``ts`` entry of the server's reply for this source."""
        return self._src["ts"]
def play(self, *args, **kwargs):
541    def play(self, *args, **kwargs):
542        """Play the video live in the notebook."""
543
544        domain = self.ts()
545        render = lambda t, i: self[t]
546        spec = Spec(domain, render, self.fmt())
547        return spec.play(*args, **kwargs)

Play the video live in the notebook.

class StorageService:
class StorageService:
    """A named storage service plus its string-valued configuration."""

    def __init__(self, service: str, **kwargs):
        if type(service) is not str:
            raise Exception("Service name must be a string")
        self._service = service
        for k in kwargs:
            if type(kwargs[k]) is not str:
                raise Exception(f"Value of {k} must be a string")
        self._config = kwargs

    def as_json(self):
        """Return the JSON-serializable description of this service."""
        return dict(service=self._service, config=self._config)

    def __repr__(self):
        return f"{self._service}(config={self._config})"
StorageService(service: str, **kwargs)
551    def __init__(self, service: str, **kwargs):
552        if type(service) != str:
553            raise Exception("Service name must be a string")
554        self._service = service
555        for k, v in kwargs.items():
556            if type(v) != str:
557                raise Exception(f"Value of {k} must be a string")
558        self._config = kwargs
def as_json(self):
    def as_json(self) -> dict:
        """Return ``{"service": <name>, "config": <config dict>}``."""
        return {"service": self._service, "config": self._config}
class Filter:
class Filter:
    """A named filter; calling it builds a ``FilterExpr`` for one invocation."""

    def __init__(self, name: str, tl_func=None, **kwargs):
        self._name = name

        # tl_func is the top level func, which is the true implementation, not just a pretty name
        self._func = name if tl_func is None else tl_func

        # filter infra args, not invocation args
        for k in kwargs:
            if type(kwargs[k]) is not str:
                raise Exception(f"Value of {k} must be a string")
        self._kwargs = kwargs

    def __call__(self, *args, **kwargs):
        return FilterExpr(self, args, kwargs)
Filter(name: str, tl_func=None, **kwargs)
581    def __init__(self, name: str, tl_func=None, **kwargs):
582        self._name = name
583
584        # tl_func is the top level func, which is the true implementation, not just a pretty name
585        if tl_func is None:
586            self._func = name
587        else:
588            self._func = tl_func
589
590        # filter infra args, not invocation args
591        for k, v in kwargs.items():
592            if type(v) != str:
593                raise Exception(f"Value of {k} must be a string")
594        self._kwargs = kwargs
class FilterExpr:
class FilterExpr:
    """One invocation of a filter with positional and keyword arguments.

    Arguments may themselves be ``FilterExpr`` or ``SourceExpr`` objects,
    forming an expression tree.
    """

    def __init__(self, filter: Filter, args, kwargs):
        self._filter, self._args, self._kwargs = filter, args, kwargs

    def __repr__(self):
        def show(value):
            # Strings are quoted; everything else uses str().
            return f'"{value}"' if type(value) == str else str(value)

        parts = [show(arg) for arg in self._args]
        parts.extend(f"{k}={show(v)}" for k, v in self._kwargs.items())
        return f"{self._filter._name}({', '.join(parts)})"

    def _to_json_spec(self):
        """Return the JSON-serializable ``Filter`` node for this invocation."""
        positional = [_json_arg(arg) for arg in self._args]
        keyword = {k: _json_arg(v) for k, v in self._kwargs.items()}
        return {
            "Filter": {
                "name": self._filter._name,
                "args": positional,
                "kwargs": keyword,
            }
        }

    def _sources(self):
        """Return the set of sources used by any expression argument."""
        found = set()
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) == FilterExpr or type(arg) == SourceExpr:
                found.update(arg._sources())
        return found

    def _filters(self):
        """Return this filter plus those of nested filter arguments, by name."""
        found = {self._filter._name: self._filter}
        for arg in (*self._args, *self._kwargs.values()):
            if type(arg) == FilterExpr:
                found.update(arg._filters())
        return found
FilterExpr(filter: Filter, args, kwargs)
    def __init__(self, filter: Filter, args, kwargs):
        """Store the filter and the positional/keyword arguments of this call."""
        self._filter = filter
        self._args = args
        self._kwargs = kwargs
class UDF:
class UDF:
    """User-defined filter superclass

    Subclasses implement ``filter`` and ``filter_type``. ``into_filter`` starts a
    separate process (target ``_run_udf_host``) and returns a ``Filter`` that
    points at a Unix socket path.

    Messages handled by ``_handle_connection`` are framed as a 4-byte
    big-endian length followed by a msgpack payload, in both directions.
    """

    # Expected Python type for each non-frame tagged value.
    _SCALAR_TYPES = {"String": str, "Int": int, "Bool": bool}

    def __init__(self, name: str):
        self._name = name
        self._socket_path = None
        self._p = None

    def filter(self, *args, **kwargs):
        """Produce the output frame; must be overridden and return a ``UDFFrame``."""
        raise Exception("User must implement the filter method")

    def filter_type(self, *args, **kwargs):
        """Produce the output frame type; must be overridden and return a ``UDFFrameType``."""
        raise Exception("User must implement the filter_type method")

    def into_filter(self):
        """Start the host process for this UDF and return a ``Filter`` bound to its socket.

        May only be called once per instance: a second call fails the assertion
        because the socket path is already set.
        """
        assert self._socket_path is None
        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
        self._p = multiprocessing.Process(
            target=_run_udf_host, args=(self, self._socket_path)
        )
        self._p.start()
        return Filter(
            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
        )

    @staticmethod
    def _recv_exact(connection, n):
        """Read ``n`` bytes from ``connection``, looping over short reads.

        Returns fewer than ``n`` bytes (possibly ``b""``) only if the peer
        closed the connection before ``n`` bytes arrived.
        """
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = connection.recv(remaining)
            if not chunk:
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    @staticmethod
    def _deser_call_args(deser, f_args, f_kwargs):
        """Apply ``deser`` to every positional and keyword argument of a request.

        ``f_kwargs`` is normally a dict (a msgpack map); iterating it directly
        would yield only its keys, so ``.items()`` is used. A sequence of
        ``(key, value)`` pairs is also accepted.
        """
        pairs = f_kwargs.items() if isinstance(f_kwargs, dict) else f_kwargs
        return [deser(x) for x in f_args], {k: deser(v) for k, v in pairs}

    def _handle_connection(self, connection):
        """Serve requests on one connection until the peer disconnects.

        Each request is a msgpack map with ``op``, ``args`` and ``kwargs``
        (it also carries ``func``, which is not used here). ``op`` is either
        ``"filter"`` or ``"filter_type"``. Any exception ends the loop; the
        connection is always closed.
        """
        try:
            while True:
                # A stream socket may deliver the header in pieces, so keep
                # reading until 4 bytes arrive or the peer closes.
                header = self._recv_exact(connection, 4)
                if len(header) != 4:
                    break
                frame_len = int.from_bytes(header, byteorder="big")
                data = self._recv_exact(connection, frame_len)
                if not data:
                    break
                if len(data) < frame_len:
                    raise Exception("Partial data received")

                obj = msgpack.unpackb(data, raw=False)
                f_op, f_args, f_kwargs = obj["op"], obj["args"], obj["kwargs"]

                if f_op == "filter":
                    f_args, f_kwargs = self._deser_call_args(
                        self._deser_filter, f_args, f_kwargs
                    )
                    response = self.filter(*f_args, **f_kwargs)
                    if type(response) != UDFFrame:
                        raise Exception(
                            f"filter must return a UDFFrame, got {type(response)}"
                        )
                    if response.frame_type().pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter must return a frame with pix_fmt 'rgb24', got {response.frame_type().pix_fmt()}"
                        )
                elif f_op == "filter_type":
                    f_args, f_kwargs = self._deser_call_args(
                        self._deser_filter_type, f_args, f_kwargs
                    )
                    response = self.filter_type(*f_args, **f_kwargs)
                    if type(response) != UDFFrameType:
                        raise Exception(
                            f"filter_type must return a UDFFrameType, got {type(response)}"
                        )
                    if response.pix_fmt() != "rgb24":
                        raise Exception(
                            f"filter_type must return a frame with pix_fmt 'rgb24', got {response.pix_fmt()}"
                        )
                else:
                    raise Exception(f"Unknown operation: {f_op}")

                payload = msgpack.packb(response._response_ser(), use_bin_type=True)
                connection.sendall(len(payload).to_bytes(4, byteorder="big"))
                connection.sendall(payload)
        finally:
            connection.close()

    @staticmethod
    def _split_tagged(obj):
        """Validate a single-key ``{tag: value}`` dict and return ``(tag, value)``."""
        assert type(obj) == dict
        keys = list(obj.keys())
        assert len(keys) == 1
        type_key = keys[0]
        assert type_key in ["Frame", "String", "Int", "Bool"]
        return type_key, obj[type_key]

    @classmethod
    def _deser_scalar(cls, type_key, value):
        """Check that a String/Int/Bool value has exactly the expected type and return it."""
        # Exact type match on purpose: a bool must not pass as an "Int".
        assert type(value) == cls._SCALAR_TYPES[type_key]
        return value

    def _deser_filter_type(self, obj):
        """Deserialize one ``filter_type`` argument.

        Frames become ``UDFFrameType``; here the frame ``format`` field is the
        integer 2 (AV_PIX_FMT_RGB24). Scalars are returned unchanged.
        """
        type_key, value = self._split_tagged(obj)
        if type_key != "Frame":
            return self._deser_scalar(type_key, value)

        frame = value
        assert type(frame) == dict
        assert "width" in frame
        assert "height" in frame
        assert "format" in frame
        assert type(frame["width"]) == int
        assert type(frame["height"]) == int
        assert frame["format"] == 2  # AV_PIX_FMT_RGB24
        return UDFFrameType(frame["width"], frame["height"], "rgb24")

    def _deser_filter(self, obj):
        """Deserialize one ``filter`` argument.

        Frames become ``UDFFrame`` with the raw bytes viewed as a
        ``(height, width, 3)`` uint8 array; here the frame ``format`` field is
        the string ``"rgb24"``. Scalars are returned unchanged.
        """
        type_key, value = self._split_tagged(obj)
        if type_key != "Frame":
            return self._deser_scalar(type_key, value)

        frame = value
        assert type(frame) == dict
        assert "data" in frame
        assert "width" in frame
        assert "height" in frame
        assert "format" in frame
        assert type(frame["width"]) == int
        assert type(frame["height"]) == int
        assert frame["format"] == "rgb24"
        assert type(frame["data"]) == bytes

        data = np.frombuffer(frame["data"], dtype=np.uint8)
        data = data.reshape(frame["height"], frame["width"], 3)
        return UDFFrame(data, UDFFrameType(frame["width"], frame["height"], "rgb24"))

    def _host(self, socket_path: str):
        """Listen on ``socket_path`` forever, handling each connection in its own thread."""
        if os.path.exists(socket_path):
            os.remove(socket_path)

        # start listening on the socket
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.bind(socket_path)
        sock.listen(1)

        while True:
            # accept incoming connection
            connection, _ = sock.accept()
            thread = threading.Thread(
                target=self._handle_connection, args=(connection,)
            )
            thread.start()

    def __del__(self):
        """Terminate the host process and remove the socket file, if ``into_filter`` ran."""
        # getattr: __del__ also runs on instances whose __init__ never ran
        # (e.g. a subclass that did not call super().__init__()).
        socket_path = getattr(self, "_socket_path", None)
        if socket_path is None:
            return
        proc = getattr(self, "_p", None)
        if proc is not None:
            proc.terminate()
        try:
            os.remove(socket_path)
        except FileNotFoundError:
            # it's possible the process hasn't even created the socket yet
            pass

User-defined filter superclass

UDF(name: str)
649    def __init__(self, name: str):
650        self._name = name
651        self._socket_path = None
652        self._p = None
def filter(self, *args, **kwargs):
654    def filter(self, *args, **kwargs):
655        raise Exception("User must implement the filter method")
def filter_type(self, *args, **kwargs):
657    def filter_type(self, *args, **kwargs):
658        raise Exception("User must implement the filter_type method")
def into_filter(self):
660    def into_filter(self):
661        assert self._socket_path is None
662        self._socket_path = f"/tmp/vidformer-{self._name}-{str(uuid.uuid4())}.sock"
663        self._p = multiprocessing.Process(
664            target=_run_udf_host, args=(self, self._socket_path)
665        )
666        self._p.start()
667        return Filter(
668            name=self._name, tl_func="IPC", socket=self._socket_path, func=self._name
669        )
class UDFFrameType:
class UDFFrameType:
    """Width, height and pixel-format name describing a UDF frame."""

    def __init__(self, width: int, height: int, pix_fmt: str):
        # Exact type checks: subclasses (such as bool for int) are rejected.
        for value, expected in ((width, int), (height, int), (pix_fmt, str)):
            assert type(value) is expected

        self._width, self._height, self._pix_fmt = width, height, pix_fmt

    def width(self):
        """Return the width given at construction."""
        return self._width

    def height(self):
        """Return the height given at construction."""
        return self._height

    def pix_fmt(self):
        """Return the pixel-format name given at construction."""
        return self._pix_fmt

    def _response_ser(self):
        """Return the ``{"frame_type": {...}}`` dict used as a response payload."""
        described = {
            "width": self._width,
            "height": self._height,
            "format": 2,  # AV_PIX_FMT_RGB24
        }
        return {"frame_type": described}

    def __repr__(self):
        return f"FrameType<{self._width}x{self._height}, {self._pix_fmt}>"
UDFFrameType(width: int, height: int, pix_fmt: str)
825    def __init__(self, width: int, height: int, pix_fmt: str):
826        assert type(width) == int
827        assert type(height) == int
828        assert type(pix_fmt) == str
829
830        self._width = width
831        self._height = height
832        self._pix_fmt = pix_fmt
def width(self):
834    def width(self):
835        return self._width
def height(self):
837    def height(self):
838        return self._height
def pix_fmt(self):
840    def pix_fmt(self):
841        return self._pix_fmt
class UDFFrame:
class UDFFrame:
    """A uint8 pixel array paired with the ``UDFFrameType`` that describes it."""

    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
        assert type(data) is np.ndarray
        assert type(f_type) is UDFFrameType

        # We only support RGB24 for now
        assert data.dtype == np.uint8
        channels = data.shape[2]
        assert channels == 3

        # check type matches: axis 0 is height, axis 1 is width
        rows, cols = data.shape[0], data.shape[1]
        assert rows == f_type.height()
        assert cols == f_type.width()
        assert f_type.pix_fmt() == "rgb24"

        self._data = data
        self._f_type = f_type

    def data(self):
        """Return the pixel array given at construction."""
        return self._data

    def frame_type(self):
        """Return the ``UDFFrameType`` given at construction."""
        return self._f_type

    def _response_ser(self):
        """Return the ``{"frame": {...}}`` dict used as a response payload."""
        ftype = self._f_type
        encoded = {
            "data": self._data.tobytes(),
            "width": ftype.width(),
            "height": ftype.height(),
            "format": "rgb24",
        }
        return {"frame": encoded}

    def __repr__(self):
        ftype = self._f_type
        return f"Frame<{ftype.width()}x{ftype.height()}, {ftype.pix_fmt()}>"
UDFFrame(data: numpy.ndarray, f_type: UDFFrameType)
857    def __init__(self, data: np.ndarray, f_type: UDFFrameType):
858        assert type(data) == np.ndarray
859        assert type(f_type) == UDFFrameType
860
861        # We only support RGB24 for now
862        assert data.dtype == np.uint8
863        assert data.shape[2] == 3
864
865        # check type matches
866        assert data.shape[0] == f_type.height()
867        assert data.shape[1] == f_type.width()
868        assert f_type.pix_fmt() == "rgb24"
869
870        self._data = data
871        self._f_type = f_type
def data(self):
873    def data(self):
874        return self._data
def frame_type(self):
876    def frame_type(self):
877        return self._f_type