@@ -236,7 +236,7 @@ astProject1
236
236
=> AstTensor AstMethodLet s (TKProduct x z ) -> AstTensor AstMethodLet s x
237
237
astProject1 u = case u of
238
238
Ast. AstPair x _z -> x
239
- Ast. AstCond b v1 v2 -> Ast. AstCond b (astProject1 v1) (astProject1 v2)
239
+ Ast. AstCond b v1 v2 -> astCond b (astProject1 v1) (astProject1 v2)
240
240
Ast. AstLet var t v -> Ast. AstLet var t (astProject1 v)
241
241
Ast. AstFromPrimal u1 -> Ast. AstFromPrimal $ astProject1 u1
242
242
Ast. AstFromDual u1 -> Ast. AstFromDual $ astProject1 u1
@@ -249,7 +249,7 @@ astProject2
249
249
=> AstTensor AstMethodLet s (TKProduct x z ) -> AstTensor AstMethodLet s z
250
250
astProject2 u = case u of
251
251
Ast. AstPair _x z -> z
252
- Ast. AstCond b v1 v2 -> Ast. AstCond b (astProject2 v1) (astProject2 v2)
252
+ Ast. AstCond b v1 v2 -> astCond b (astProject2 v1) (astProject2 v2)
253
253
Ast. AstLet var t v -> Ast. AstLet var t (astProject2 v)
254
254
Ast. AstFromPrimal u1 -> Ast. AstFromPrimal $ astProject2 u1
255
255
Ast. AstFromDual u1 -> Ast. AstFromDual $ astProject2 u1
@@ -724,6 +724,7 @@ astCond :: AstBool AstMethodLet
724
724
-> AstTensor AstMethodLet s y -> AstTensor AstMethodLet s y
725
725
-> AstTensor AstMethodLet s y
726
726
astCond (AstBoolConst b) v w = if b then v else w
727
+ astCond (Ast. AstBoolNot b) v w = astCond b w v
727
728
astCond b (Ast. AstFromPrimal v) (Ast. AstFromPrimal w) =
728
729
Ast. AstFromPrimal $ astCond b v w
729
730
astCond b (Ast. AstFromDual v) (Ast. AstFromDual w) =
0 commit comments