From 97a1f15bc5f396c417b3550ced3cbedf54bf4c3b Mon Sep 17 00:00:00 2001 From: Jonathan Lennox Date: Mon, 29 Jul 2024 13:34:59 -0700 Subject: [PATCH] Fix: Support VP9 flexible mode. (#2199) VP9 flexible mode doesn't announce temporal layers in advance, so add them to the encoding desc as they are encountered. --- .../kotlin/org/jitsi/nlj/MediaSourceDesc.kt | 8 ++ .../main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt | 2 +- .../nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt | 7 +- .../org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt | 3 + .../org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt | 80 +++++++++++++++++-- .../nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt | 10 ++- .../node/incoming/BitrateCalculator.kt | 8 +- .../cc/allocation/BitrateControllerTest.kt | 2 +- 8 files changed, 104 insertions(+), 16 deletions(-) diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt index 2a6fb30188..b529791bcc 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/MediaSourceDesc.kt @@ -148,6 +148,14 @@ class MediaSourceDesc @Synchronized fun findRtpEncodingDesc(ssrc: Long): RtpEncodingDesc? = rtpEncodings.find { it.matches(ssrc) } + @Synchronized + fun getEncodingLayers(ssrc: Long): Array { + val enc = findRtpEncodingDesc(ssrc) ?: return emptyArray() + return Array(enc.layers.size) { i -> + enc.layers[i].copy() + } + } + @Synchronized fun setEncodingLayers(layers: Array, ssrc: Long) { val enc = findRtpEncodingDesc(ssrc) ?: return diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt index 9cd30d9383..49f0474c45 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/RtpLayerDesc.kt @@ -56,7 +56,7 @@ constructor( */ val frameRate: Double, ) { - abstract fun copy(height: Int = this.height): RtpLayerDesc + abstract fun copy(height: Int = this.height, tid: Int = this.tid, inherit: Boolean = true): RtpLayerDesc /** * The [BitrateTracker] instance used to calculate the receiving bitrate of this RTP layer. diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt index 38952a2415..2d1eccbf0b 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/av1/Av1DDRtpLayerDesc.kt @@ -52,7 +52,12 @@ class Av1DDRtpLayerDesc( */ frameRate: Double, ) : RtpLayerDesc(eid, tid, sid, height, frameRate) { - override fun copy(height: Int): RtpLayerDesc = Av1DDRtpLayerDesc(eid, dt, tid, sid, height, frameRate) + override fun copy(height: Int, tid: Int, inherit: Boolean): RtpLayerDesc = + Av1DDRtpLayerDesc(eid, dt, tid, sid, height, frameRate).also { + if (inherit) { + it.inheritFrom(this) + } + } override val layerId = dt override val index = getIndex(eid, dt) diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt index ab05bf5d82..3b1a31e99b 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Packet.kt @@ -96,6 +96,9 @@ class Vp9Packet private constructor( val hasExtendedPictureId = DePacketizer.VP9PayloadDescriptor.hasExtendedPictureId(buffer, payloadOffset, payloadLength) + val isFlexibleMode = + DePacketizer.VP9PayloadDescriptor.isFlexibleMode(buffer, payloadOffset, payloadLength) + val hasScalabilityStructure = DePacketizer.VP9PayloadDescriptor.hasScalabilityStructure(buffer, payloadOffset, payloadLength) diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt index d1fa5a9493..05300fff8b 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vp9/Vp9Parser.kt @@ -18,12 +18,14 @@ package org.jitsi.nlj.rtp.codec.vp9 import org.jitsi.nlj.MediaSourceDesc import org.jitsi.nlj.PacketInfo +import org.jitsi.nlj.RtpLayerDesc import org.jitsi.nlj.rtp.codec.VideoCodecParser import org.jitsi.nlj.rtp.codec.vpx.VpxRtpLayerDesc import org.jitsi.nlj.util.StateChangeLogger import org.jitsi.rtp.extensions.toHex import org.jitsi.utils.logging2.Logger import org.jitsi.utils.logging2.createChildLogger +import kotlin.math.max /** * Some [Vp9Packet] fields are not able to be determined by looking at a single VP9 packet (for example the scalability @@ -40,13 +42,23 @@ class Vp9Parser( private val extendedPictureIdState = StateChangeLogger("missing extended picture ID", logger) private var numSpatialLayers = -1 - /** Encodings we've actually seen. Used to clear out inferred-from-signaling encoding information. */ - private val ssrcsSeen = HashSet() + /** Encodings we've actually seen, and the layers seen for each one. + * Used to clear out inferred-from-signaling encoding information, and to synthesize temporal layers + * for flexible-mode encodings. */ + private val ssrcsInfo = HashMap>() override fun parse(packetInfo: PacketInfo) { val vp9Packet = packetInfo.packetAs() - ssrcsSeen.add(vp9Packet.ssrc) + val layerMap = ssrcsInfo.getOrPut(vp9Packet.ssrc) { + HashMap() + } + + layerMap[vp9Packet.spatialLayerIndex]?.let { + layerMap[vp9Packet.spatialLayerIndex] = max(it, vp9Packet.temporalLayerIndex) + } ?: run { + layerMap[vp9Packet.spatialLayerIndex] = vp9Packet.temporalLayerIndex + } if (vp9Packet.hasScalabilityStructure) { // TODO: handle case where new SS is from a packet older than the @@ -58,12 +70,31 @@ class Vp9Parser( } numSpatialLayers = packetSpatialLayers } - findRtpEncodingDesc(vp9Packet)?.let { enc -> - vp9Packet.getScalabilityStructure(eid = enc.eid)?.let { - source.setEncodingLayers(it.layers, vp9Packet.ssrc) - } + val ss = findRtpEncodingDesc(vp9Packet)?.let { enc -> + vp9Packet.getScalabilityStructure(eid = enc.eid) + } + + if (ss != null) { + val layers = + if (vp9Packet.isFlexibleMode) { + /* In flexible mode, the number of temporal layers isn't announced in the keyframe. + * Thus, add temporal layer information to the source's encoding layers based on the temporal + * layers we've seen previously. + */ + val layersList = ss.layers.toMutableList() + + for ((sid, maxTid) in layerMap) { + addTemporalLayers(layersList, sid, maxTid) + } + layersList.toTypedArray() + } else { + ss.layers + } + + source.setEncodingLayers(layers, vp9Packet.ssrc) + for (otherEnc in source.rtpEncodings) { - if (!ssrcsSeen.contains(otherEnc.primarySSRC)) { + if (!ssrcsInfo.contains(otherEnc.primarySSRC)) { source.setEncodingLayers(emptyArray(), otherEnc.primarySSRC) } } @@ -82,6 +113,19 @@ class Vp9Parser( } } + if (vp9Packet.isFlexibleMode && findRtpLayerDescs(vp9Packet).isEmpty()) { + val layers = source.getEncodingLayers(vp9Packet.ssrc).toMutableList() + /* In flexible mode, the number of temporal layers isn't announced in the keyframe. + * Thus, add temporal layer information to the source's encoding layers as we see packets with + * temporal layers. + */ + val changed = addTemporalLayers(layers, vp9Packet.spatialLayerIndex, vp9Packet.temporalLayerIndex) + if (changed) { + source.setEncodingLayers(layers.toTypedArray(), vp9Packet.ssrc) + packetInfo.layeringChanged = true + } + } + pictureIdState.setState(vp9Packet.hasPictureId, vp9Packet) { "Packet Data: ${vp9Packet.toHex(80)}" } @@ -89,4 +133,24 @@ class Vp9Parser( "Packet Data: ${vp9Packet.toHex(80)}" } } + + /** Add temporal layers to the list of layers. Needed if VP9 is encoded in flexible mode, because + * in flexible mode the scalability structure doesn't describe the temporal layers. + */ + private fun addTemporalLayers(layers: MutableList, sid: Int, maxTid: Int): Boolean { + var changed = false + + for (tid in 1..maxTid) { + val layer = layers.find { it.sid == sid && it.tid == tid } + if (layer == null) { + val prevLayer = layers.find { it.sid == sid && it.tid == tid - 1 } + if (prevLayer != null) { + val newLayer = prevLayer.copy(tid = tid, inherit = false) + layers.add(newLayer) + changed = true + } + } + } + return changed + } } diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt index 9c42b14e04..0244e0bd1a 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/rtp/codec/vpx/VpxRtpLayerDesc.kt @@ -70,19 +70,21 @@ constructor( } /** - * Clone an existing layer desc, inheriting its statistics, + * Clone an existing layer desc, inheriting its statistics if [inherit], * modifying only specific values. */ - override fun copy(height: Int) = VpxRtpLayerDesc( + override fun copy(height: Int, tid: Int, inherit: Boolean) = VpxRtpLayerDesc( eid = this.eid, - tid = this.tid, + tid = tid, sid = this.sid, height = height, frameRate = this.frameRate, dependencyLayers = this.dependencyLayers, softDependencyLayers = this.softDependencyLayers ).also { - it.inheritFrom(this) + if (inherit) { + it.inheritFrom(this) + } } /** diff --git a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt index 2daf4cc1da..e20eb85f7e 100644 --- a/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt +++ b/jitsi-media-transform/src/main/kotlin/org/jitsi/nlj/transform/node/incoming/BitrateCalculator.kt @@ -56,7 +56,13 @@ class VideoBitrateCalculator( val videoRtpPacket: VideoRtpPacket = packetInfo.packet as VideoRtpPacket val now = clock.millis() - mediaSourceDescs.findRtpLayerDescs(videoRtpPacket).forEach { + val layerDescs = mediaSourceDescs.findRtpLayerDescs(videoRtpPacket) + + if (layerDescs.isEmpty()) { + logger.warn("No layer found for packet $videoRtpPacket") + } + + layerDescs.forEach { if (it.updateBitrate(videoRtpPacket.length.bytes, now)) { /* When a layer is started when it was previously inactive, * we want to recalculate bandwidth allocation. diff --git a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt index 4649e80586..a3f0f13156 100644 --- a/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt +++ b/jvb/src/test/kotlin/org/jitsi/videobridge/cc/allocation/BitrateControllerTest.kt @@ -1528,7 +1528,7 @@ class MockRtpLayerDesc( var bitrate: Bandwidth, sid: Int = -1 ) : RtpLayerDesc(eid, tid, sid, height, frameRate) { - override fun copy(height: Int): RtpLayerDesc { + override fun copy(height: Int, tid: Int, inherit: Boolean): RtpLayerDesc { TODO("Not yet implemented") }