-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_2gram_hmm.py
141 lines (118 loc) · 4.9 KB
/
create_2gram_hmm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
'''
create_2gram_hmm.py
Vincent Soesanto
CSE 415
Autumn 2019
This script generates a bigram model of a pos-enhanced input corpus 'training_data' using the ARPA format.
Usage: cat training_data | create_2gram_hmm.sh output_hmm
'''
import sys
import math
# command line args
# The corpus is read from stdin; the single positional argument is the path
# the model is written to.
input_file = sys.stdin
output_file_name = sys.argv[1]
# home development
# input_file_name = "examples/wsj_sec0.word_pos"
# output_file_name = "wsj_hmm_2g"
# global variables
# NOTE: despite the *_prob names, these tables hold raw counts while the corpus
# is being read; the counts are divided by tag_unigrams totals only when the
# report is written.
transition_prob = {} # state to state: "tag1 tag2" -> [count, tag1]
emission_prob = {} # state to obsv: "tag word" -> [count, tag]
word_unigrams = {} # word -> number of occurrences
tag_unigrams = {} # tag -> number of occurrences
def take_inventory(split_line):
    '''
    Fold one tagged sentence into the module-level count tables.

    split_line is a list of "word/tag" strings, normally wrapped in "<s>/BOS"
    and "</s>/EOS". Each string is replaced in place by its (word, tag) tuple,
    and the sentence then updates:
        transition_prob["tag1 tag2"] = [count, tag1]   (adjacent tag pairs)
        emission_prob["tag word"]    = [count, tag]    (every token)
        word_unigrams[word], tag_unigrams[tag]         (occurrence counts)

    Raises ValueError, naming the offending token, when a token has no "/"
    between word and tag. The whole sentence is parsed before anything is
    modified, so a rejected sentence leaves split_line and the tables untouched.
    '''
    pairs = []
    for item in split_line:
        if "</s>" in item:
            # A right-split would cut a bare "</s>" at its own slash, so the
            # end-of-sentence marker is recognised explicitly.
            pairs.append(("</s>", "EOS"))
            continue
        pair = item.rsplit("/", maxsplit=1)  # last "/" separates word from tag
        if len(pair) != 2:
            raise ValueError('token {!r} is not in "word/tag" form'.format(item))
        pairs.append((pair[0], pair[1]))

    # Keep the original contract: the caller's list ends up holding tuples.
    split_line[:] = pairs

    for i, (word, tag) in enumerate(pairs):
        # state to state, e.g. "BOS N"; the last token has no successor
        if i + 1 < len(pairs):
            bigram = (tag + " " + pairs[i + 1][1]).strip(" ")
            # Only keys that are exactly two space-separated fields are kept
            # (an empty tag drops the entry), so every output line stays
            # unambiguous.
            if len(bigram.split(" ")) == 2:
                seen = transition_prob[bigram][0] if bigram in transition_prob else 0
                transition_prob[bigram] = [seen + 1, tag]

        # state to obsv, e.g. "A cool" (tag first, then word)
        emission = (tag + " " + word).strip(" ")
        if len(emission.split(" ")) == 2:
            seen = emission_prob[emission][0] if emission in emission_prob else 0
            emission_prob[emission] = [seen + 1, tag]

        # word and tag unigram counts
        word_unigrams[word] = word_unigrams.get(word, 0) + 1
        tag_unigrams[tag] = tag_unigrams.get(tag, 0) + 1
# DRIVER
# Read the tagged corpus from stdin, one sentence per line, and accumulate the
# counts. split() with no argument tolerates tabs and repeated spaces, and
# blank lines are skipped so they neither reach take_inventory() as an empty
# token nor get counted as empty BOS -> EOS sentences.
for line in input_file:
    tokens = line.split()
    if not tokens:
        continue
    take_inventory(["<s>/BOS"] + tokens + ["</s>/EOS"])

# Every probability below is normalised by a tag count, so an empty corpus
# cannot yield a model; stop with a clear message instead of a KeyError raised
# after the output file has been partly written.
if "BOS" not in tag_unigrams:
    sys.exit("create_2gram_hmm: no sentences read from stdin; nothing to write")

# REPORT
with open(output_file_name, "w") as output_file:
    # header: table sizes
    output_file.write("state_num={}\n".format(len(tag_unigrams)))
    output_file.write("sym_num={}\n".format(len(word_unigrams)))
    output_file.write("init_line_num=1\n")
    output_file.write("trans_line_num={}\n".format(len(transition_prob)))
    output_file.write("emiss_line_num={}\n\n".format(len(emission_prob)))

    # The backslash is escaped explicitly: "\i" is an invalid escape sequence
    # that only happens to produce the same two characters.
    output_file.write("\\init\n")

    # init probability: share of BOS occurrences that are followed by a state.
    # The stored from-tag is compared exactly so that a tag merely containing
    # the letters "BOS" is not counted.
    init_count = sum(
        count for count, from_tag in transition_prob.values() if from_tag == "BOS"
    )
    init_prob = init_count / tag_unigrams["BOS"]
    output_file.write(
        "BOS {:.10f} {:.10f}\n\n".format(init_prob, math.log10(init_prob))
    )

    sections = (("transition", transition_prob), ("emission", emission_prob))
    for prob_type, inventory in sections:
        output_file.write("\\" + prob_type + "\n")
        # keys are written in sorted order
        for key in sorted(inventory):
            count, from_tag = inventory[key]
            # relative frequency: count of the pair over count of its state
            p = count / tag_unigrams[from_tag]
            output_file.write(
                "{} {:.10f} {:.10f}\n".format(key, p, math.log10(p))
            )
        # blank line closes each section
        output_file.write("\n")