From e20108878caf51f97b151e45cf127212487fb56a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 27 Oct 2022 21:17:23 +0200 Subject: [PATCH] fix attention weight inside .swap() --- ldm/invoke/prompt_parser.py | 13 +++++++------ tests/test_prompt_parser.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 8807c7986b..4a6d470140 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -129,7 +129,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): default_options = { 's_start': 0.0, - 's_end': 0.206, # ~= shape_freedom=0.5 + 's_end': 0.2062994740159002, # ~= shape_freedom=0.5 't_start': 0.0, 't_end': 1.0 } @@ -145,7 +145,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): # so for shape_freedom = 0.5 we probably want s_end to be 0.2 # -> cube root and subtract from 1.0 merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.) - print('converted shape_freedom argument to', merged_options) + #print('converted shape_freedom argument to', merged_options) merged_options.update(options) self.options = merged_options @@ -514,10 +514,11 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) # cross attention control debug_cross_attention_control = False - original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), + original_fragment = pp.MatchFirst([ quoted_fragment.set_debug(debug_cross_attention_control), parenthesized_fragment.set_debug(debug_cross_attention_control), - pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap") + pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"), + empty_string.set_debug(debug_cross_attention_control), ]) # support keyword=number arguments cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")]) @@ -525,8 +526,8 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) edited_fragment = pp.MatchFirst([ (lparen + rparen).set_parse_action(lambda x: Fragment('')), lparen + - (quoted_fragment | - pp.Group(pp.ZeroOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment))) + (quoted_fragment | attention | + pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment))) ) + pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) + rparen, diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 4fd7616ade..f2ac1b9999 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -250,6 +250,33 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)")) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)")) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), + Fragment('eating a', 1), + CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))]) + ])]), + parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)")) + def test_cross_attention_control_options(self): self.assertEqual(Conjunction([ FlattenedPrompt([Fragment('a', 1),