diff --git a/src/components/TreeReactFlow/Types.ts b/src/components/TreeReactFlow/Types.ts index d462b2ae..f9b50841 100644 --- a/src/components/TreeReactFlow/Types.ts +++ b/src/components/TreeReactFlow/Types.ts @@ -1,4 +1,5 @@ import { + EdgeFlavor, GlobalName, NodeFlavorBoxBody, NodeFlavorNoBody, @@ -229,9 +230,19 @@ export type PrimerEdge = { sourceHandle: Position; targetHandle: Position; zIndex: number; + // label?: string; + // TODO why? this seems like a bad abstraction + // but we can't specify classes in `EdgeBase` + // https://github.com/xyflow/xyflow/issues/420 + // ReactFlow seems to magically make use of this class instead + // in _nodes_ we can use classes, since we're not using an equivalent `NodeBase` + className: string; } & ({ type: "primer"; data: PrimerEdgeProps } | { type: "primer-def" }); -export type PrimerEdgeProps = { flavor: NodeFlavor }; +export type PrimerEdgeProps = { + flavor: NodeFlavor; + edgeFlavor: EdgeFlavor; +}; export type Positioned = T & { position: { x: number; y: number }; diff --git a/src/components/TreeReactFlow/index.tsx b/src/components/TreeReactFlow/index.tsx index 4ee33002..c9bd21e8 100644 --- a/src/components/TreeReactFlow/index.tsx +++ b/src/components/TreeReactFlow/index.tsx @@ -7,6 +7,7 @@ import { NodeType, Level, TypeDef, + EdgeFlavor, } from "@/primer-api"; import type { NodeChange } from "@xyflow/react"; import { @@ -23,6 +24,8 @@ import { getBezierPath, EdgeTypes, useReactFlow, + BaseEdge, + EdgeLabelRenderer, } from "@xyflow/react"; import "./reactflow.css"; import { MutableRefObject, PropsWithChildren, useId } from "react"; @@ -55,6 +58,7 @@ import { } from "./Types"; import { LayoutParams, layoutTree } from "./layoutTree"; import { + NodeFlavor, boxFlavorBackground, commonHoverClasses, flavorClasses, @@ -154,7 +158,7 @@ export const inlineTreeReactFlowProps: typeof defaultTreeReactFlowProps = { boxPadding: 35, layout: { ...defaultTreeReactFlowProps.layout, - margins: { child: 15, sibling: 12 }, + margins: { child: 19, sibling: 12 }, }, }; @@ -456,6 +460,8 @@ const edgeTypes = { primer: ({ data, id, + // TODO split out common edge props to ensure this is inferred to be a string + // label, sourceX, sourceY, sourcePosition, @@ -463,7 +469,7 @@ const edgeTypes = { targetY, targetPosition, }: EdgeProps & { data: PrimerEdgeProps }) => { - const [edgePath] = getBezierPath({ + const [edgePath, labelX, labelY] = getBezierPath({ sourceX, sourceY, sourcePosition, @@ -471,15 +477,70 @@ const edgeTypes = { targetY, targetPosition, }); + const label = (() => { + switch (data.edgeFlavor) { + case "Hole": + return undefined; + case "AnnTerm": + return undefined; + case "Ann": + return undefined; + case "AppFun": + return undefined; + case "AppArg": + return undefined; + case "ConField": + return undefined; + case "Lam": + return undefined; + case "LetEqual": + return "="; + case "LetIn": + return "in"; + case "MatchInput": + return undefined; + case "Pattern": + return undefined; + case "MatchOutput": + return undefined; + case "FunIn": + return undefined; + case "FunOut": + return undefined; + case "ForallKind": + return undefined; + case "Forall": + return undefined; + case "Bind": + return undefined; + } + })(); return ( - + + {label && ( + +
+ {label} +
+
)} - d={edgePath} - /> + ); }, "primer-def": ({ @@ -490,8 +551,9 @@ const edgeTypes = { targetX, targetY, targetPosition, + label, }: EdgeProps) => { - const [edgePath] = getSmoothStepPath({ + const [edgePath, labelX, labelY] = getSmoothStepPath({ sourceX, sourceY, sourcePosition, @@ -501,12 +563,24 @@ const edgeTypes = { offset: 0, }); return ( - + <> + + +
+ {label} +
+
+ ); }, }; @@ -520,10 +594,14 @@ type APITreeNode = { const augmentTree = async ( tree: APITree, - f: (tree: APITreeNode) => Promise<[T, (child: T, isRight: boolean) => E]> + f: ( + tree: APITreeNode + ) => Promise<[T, (child: T, flavor: EdgeFlavor, isRight: boolean) => E]> ): Promise> => { const childTrees = await Promise.all( - tree.childTrees.map((t) => augmentTree(t.snd, f)) + tree.childTrees.map((t) => + augmentTree(t.snd, f).then((r) => [r, t.fst] as const) + ) ); const [node, makeEdge] = await f({ children: tree.childTrees.length + (tree.rightChild ? 1 : 0), @@ -533,10 +611,18 @@ const augmentTree = async ( ? augmentTree(tree.rightChild.snd, f) : undefined); return { - ...(rightChild - ? { rightChild: [rightChild, makeEdge(rightChild.node, true)] } + ...(rightChild && tree.rightChild + ? { + rightChild: [ + rightChild, + makeEdge(rightChild.node, tree.rightChild.fst, true), + ], + } : {}), - childTrees: childTrees.map((e) => [e, makeEdge(e.node, false)]), + childTrees: childTrees.map(([e, flavor]) => [ + e, + makeEdge(e.node, flavor, false), + ]), node, }; }; @@ -551,7 +637,7 @@ const makePrimerNode = async ( ): Promise< [ PrimerNode, - (child: PrimerNode, isRight: boolean) => PrimerEdge, + (child: PrimerNode, flavor: EdgeFlavor, isRight: boolean) => PrimerEdge, /* Nodes of nested trees, already positioned. We have to lay these out first in order to know the dimensions of boxes to be drawn around them.*/ PrimerGraph[], @@ -574,15 +660,17 @@ const makePrimerNode = async ( showIDs: p.showIDs, }; const edgeCommon = ( + flavor: NodeFlavor, child: PrimerNode, isRight: boolean - ): Omit => ({ + ): Omit => ({ id: JSON.stringify([id, child.id]), source: id, target: child.id, zIndex, sourceHandle: isRight ? Position.Right : Position.Bottom, targetHandle: isRight ? Position.Left : Position.Top, + className: flavorEdgeClasses(flavor), }); const width = (hideLabel: boolean) => p.style == "inline" && !hideLabel @@ -600,11 +688,10 @@ const makePrimerNode = async ( data: { contents: prim.contents, ...common }, zIndex, }, - (child, isRight) => ({ + (child, edgeFlavor, isRight) => ({ type: "primer", - data: { flavor }, - className: flavorEdgeClasses(flavor), - ...edgeCommon(child, isRight), + data: { flavor, edgeFlavor }, + ...edgeCommon(flavor, child, isRight), }), [], ]; @@ -631,11 +718,10 @@ const makePrimerNode = async ( }, zIndex, }, - (child, isRight) => ({ + (child, edgeFlavor, isRight) => ({ type: "primer", - data: { flavor }, - className: flavorEdgeClasses(flavor), - ...edgeCommon(child, isRight), + data: { flavor, edgeFlavor }, + ...edgeCommon(flavor, child, isRight), }), [], ]; @@ -657,21 +743,24 @@ const makePrimerNode = async ( }, zIndex, }, - (child, isRight) => ({ + (child, edgeFlavor, isRight) => ({ type: "primer", - data: { flavor }, - className: flavorEdgeClasses(flavor), - ...edgeCommon(child, isRight), + data: { flavor, edgeFlavor }, + ...edgeCommon(flavor, child, isRight), }), [], ]; } case "NoBody": { const flavor = node.body.contents; - const makeChild = (child: PrimerNode, isRight: boolean): PrimerEdge => ({ + const makeChild = ( + child: PrimerNode, + edgeFlavor: EdgeFlavor, + isRight: boolean + ): PrimerEdge => ({ type: "primer", - data: { flavor }, - ...edgeCommon(child, isRight), + data: { flavor, edgeFlavor }, + ...edgeCommon(flavor, child, isRight), }); if (p.level == "Beginner") { return [ @@ -747,11 +836,10 @@ const makePrimerNode = async ( }, zIndex, }, - (child, isRight) => ({ + (child, edgeFlavor, isRight) => ({ type: "primer", - data: { flavor }, - className: flavorEdgeClasses(flavor), - ...edgeCommon(child, isRight), + data: { flavor, edgeFlavor }, + ...edgeCommon(flavor, child, isRight), }), bodyNested.concat({ nodes: bodyLayout.nodes.map((node) => ({ @@ -943,6 +1031,7 @@ const defToTree = async ( zIndex: 0, sourceHandle: Position.Bottom, targetHandle: Position.Top, + className: "stroke-grey-tertiary", }, ]); const sigTree = await defEdge(def.type_, "SigNode", sigEdgeId); @@ -1024,6 +1113,7 @@ const typeDefToTree = async ( targetHandle: Position.Top, zIndex: 0, type: "primer-def", + className: "stroke-grey-tertiary", }; return [ { @@ -1041,6 +1131,7 @@ const typeDefToTree = async ( zIndex: 0, sourceHandle: Position.Right, targetHandle: Position.Left, + className: "stroke-grey-tertiary", }), ]; }, undefined); @@ -1079,6 +1170,7 @@ const typeDefToTree = async ( zIndex: 0, sourceHandle: Position.Bottom, targetHandle: Position.Top, + className: "stroke-grey-tertiary", }, ]) ) @@ -1118,6 +1210,7 @@ const typeDefToTree = async ( sourceHandle: Position.Bottom, targetHandle: Position.Top, zIndex: 0, + className: "stroke-grey-tertiary", }, ]; })